
[Feature] Integrate Elastic NIXL-EP into SGLang #19248

Merged
ShangmingCai merged 13 commits into sgl-project:main from zackyoray:elastic_nixl_ep_moe
Mar 11, 2026

Conversation

@zackyoray
Contributor

Overview

This PR introduces support for the NIXL-EP MoE backend in SGLang, enabling efficient expert parallelism through NVIDIA's NIXL framework. This implementation leverages the elastic expert parallelism infrastructure being developed as part of the Elastic EP Support roadmap (PR #8961).

What is NIXL-EP?

NIXL-EP is a complete implementation of expert-parallel communication for Mixture of Experts (MoE) models built on top of NIXL's device API. It provides elastic scaling and fault tolerance, enabling dynamic addition and removal of processes (ranks) during runtime without disrupting existing connections, and leverages NIXL's RDMA and NVLink support for optimal performance.
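At a high level, an EP dispatcher performs a two-stage round trip: *dispatch* routes each token to the rank hosting its chosen expert, and *combine* returns the expert outputs to the original token order. A toy, single-process sketch of that pattern (plain Python lists stand in for NIXL's RDMA buffers; the function names are illustrative, not this PR's API):

```python
def dispatch(tokens, topk_idx, num_experts):
    """Group tokens by the expert each one is routed to (toy all-to-all)."""
    buckets = {e: [] for e in range(num_experts)}
    for token, expert in zip(tokens, topk_idx):
        buckets[expert].append(token)
    return buckets

def combine(buckets, topk_idx):
    """Restore per-expert outputs to the original token order."""
    iters = {e: iter(vals) for e, vals in buckets.items()}
    return [next(iters[expert]) for expert in topk_idx]

# Three tokens routed across two experts, then recombined in order.
buckets = dispatch([10, 20, 30], topk_idx=[1, 0, 1], num_experts=2)
restored = combine(buckets, topk_idx=[1, 0, 1])
```

In the real backend, each bucket lives on a different rank and the transfers go over RDMA/NVLink, but the dispatch-then-combine symmetry is the same.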

Testing & Performance

The implementation has been validated with DeepSeek-V3-Lite using the standard python -m sglang.bench_serving benchmark tool.

Test Configuration

| Parameter | 1 Node | 2 Nodes |
|---|---|---|
| Model | DeepSeek-V3-Lite (FP8) | DeepSeek-V3-Lite (FP8) |
| Tensor Parallelism | 8 | 16 |
| Data Parallelism | 8 | 8 |
| Max Concurrency | 256 | 256 |
| Number of Prompts | 4096 | 4096 |
| Input Length | 128 tokens | 128 tokens |
| Output Length | 128 tokens | 128 tokens |
| Redundant Experts | 24 | 24 |
| Memory Fraction | 0.78 | 0.78 |

Performance Results

1 Node (8 GPUs)

| Backend | TTFT Mean (ms) | TTFT Median (ms) | E2E Latency Mean (ms) | E2E Latency Median (ms) | Request Throughput (req/s) |
|---|---|---|---|---|---|
| nixl | 239.31 | 195.17 | 11,416.26 | 7,201.92 | 20.03 |
| deepep | 251.15 | 189.95 | 11,219.45 | 7,051.53 | 20.34 |

2 Nodes (16 GPUs)

| Backend | TTFT Mean (ms) | TTFT Median (ms) | E2E Latency Mean (ms) | E2E Latency Median (ms) | Request Throughput (req/s) |
|---|---|---|---|---|---|
| nixl | 272.97 | 221.46 | 13,153.86 | 8,380.98 | 17.61 |
| deepep | 269.97 | 213.27 | 12,706.45 | 8,209.62 | 18.39 |

Additional testing across different model scales and cluster configurations is ongoing.

Related Work


Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@gemini-code-assist
Contributor

Summary of Changes

Hello @zackyoray, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances SGLang's Mixture of Experts (MoE) capabilities by integrating the Elastic NIXL-EP backend. This integration provides a new, high-performance option for expert parallelism, building upon existing elastic EP infrastructure. The changes involve adding NIXL as a recognized communication backend, implementing its specific dispatcher logic, and establishing a robust coordination mechanism for distributed operations. This expands the framework's flexibility and efficiency for large-scale MoE model serving.

Highlights

  • NIXL-EP Backend Integration: Introduced support for Elastic NIXL-EP as a new backend for all-to-all communication in Mixture of Experts (MoE) models, leveraging NVIDIA's NIXL framework for efficient expert parallelism.
  • Global TCPStore for Coordination: Implemented a global TCPStore mechanism to facilitate coordination across distributed ranks, specifically for NIXL buffer setup and other distributed components.
  • Expanded MoE Backend Options: Added 'nixl' as a selectable option for --moe-a2a-backend and --elastic-ep-backend in server arguments and documentation, allowing users to choose NIXL for MoE communication.
  • Elasticity-Aware Hierarchical EPLB Algorithm: Extended the Expert Parallelism Load Balancing (EPLB) algorithms with a new 'elasticity_aware_hierarchical' option, enhancing expert distribution strategies.
  • DeepSeek-V3-Lite Validation: Validated the NIXL-EP implementation with DeepSeek-V3-Lite using standard benchmarks, demonstrating comparable performance to existing backends across single and multi-node configurations.


Changelog
  • docs/advanced_features/server_arguments.md
    • Added 'nixl' as a valid option for the --moe-a2a-backend argument in the server arguments documentation.
  • python/sglang/srt/batch_overlap/two_batch_overlap.py
    • Imported NixlEPDispatcher for use in batch overlap processing.
    • Added conditional logic to instantiate NixlEPDispatcher when the 'nixl' MoE A2A backend is selected.
  • python/sglang/srt/distributed/parallel_state.py
    • Imported set_global_tcp_store from sglang.srt.distributed.utils.
    • Added a new private function _create_global_tcp_store to establish a global TCPStore for inter-rank coordination.
    • Integrated the call to _create_global_tcp_store into the init_distributed_environment function to ensure TCPStore initialization during distributed setup.
  • python/sglang/srt/distributed/utils.py
    • Introduced global variables and functions (_global_tcp_store, set_global_tcp_store, get_global_tcp_store) to manage a single, shared TCPStore instance across the distributed environment.
  • python/sglang/srt/eplb/eplb_algorithms/init.py
    • Added elasticity_aware_hierarchical as a new member to the EplbAlgorithm enum.
    • Modified the rebalance_experts function to support the new elasticity_aware_hierarchical algorithm, enabling hierarchical rebalancing for elastic EP.
  • python/sglang/srt/layers/moe/ep_moe/layer.py
    • Updated the get_moe_impl_class function to return DeepEPMoE when the 'nixl' A2A backend is active, aligning NIXL with existing DeepEP implementations.
  • python/sglang/srt/layers/moe/fused_moe_triton/layer.py
    • Modified create_moe_dispatcher to instantiate MaybeTboDeepEPDispatcher when 'nixl' is the active A2A backend.
  • python/sglang/srt/layers/moe/token_dispatcher/init.py
    • Imported NixlEPCombineInput, NixlEPDispatcher, and NixlEPDispatchOutput from the new nixl token dispatcher module.
    • Exported the newly imported NIXL-EP related classes to make them accessible within the package.
  • python/sglang/srt/layers/moe/token_dispatcher/nixl.py
    • Added a new file implementing the NixlEPDispatchOutput and NixlEPCombineInput NamedTuples for NIXL-EP data structures.
    • Implemented NixlEPBuffer as a class method to manage and provide the NIXL buffer, including initialization and connection logic using the global TCPStore.
    • Defined _NixlEPDispatcherImplBase and _NixlEPDispatcherImpl to handle the core dispatch and combine operations for NIXL-EP, supporting FP8 and asynchronous finishing.
    • Created the NixlEPDispatcher class, inheriting from BaseDispatcher, to orchestrate the dispatch and combine stages for NIXL-EP, including stage management and error handling for unsupported modes.
  • python/sglang/srt/layers/moe/utils.py
    • Added NIXL as a new member to the MoeA2ABackend enum.
    • Implemented the is_nixl method within MoeA2ABackend to check if the current backend is NIXL.
  • python/sglang/srt/layers/quantization/fp8.py
    • Updated is_deepgemm_moe_runner_backend_enabled to include the 'nixl' A2A backend in its auto-detection logic for DeepGEMM.
  • python/sglang/srt/models/deepseek_v2.py
    • Modified the __init__ method to include 'nixl' in the conditions for setting tp_rank and tp_size for MoE, and for enabling A2A MoE.
  • python/sglang/srt/models/glm4_moe.py
    • Adjusted the __init__ method to incorporate 'nixl' into the conditions for calculating ep_size and enabling A2A MoE.
  • python/sglang/srt/server_args.py
    • Added 'nixl' to the list of valid choices for moe_a2a_backend and elastic_ep_backend.
    • Included logic in _handle_a2a_moe to adjust ep_size and issue a warning when 'nixl' MoE is enabled.
    • Updated the assertion for eplb_algorithm in _handle_elastic_ep to allow 'elasticity_aware_hierarchical' when Elastic EP is active.
    • Updated the help message for --elastic-ep-backend to reflect support for 'nixl'.
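The `MoeA2ABackend` change in `python/sglang/srt/layers/moe/utils.py` can be sketched roughly as follows (the `NIXL` member and `is_nixl` method come from the changelog above; the other enum members shown here are assumptions for illustration):

```python
from enum import Enum

class MoeA2ABackend(Enum):
    """Which all-to-all communication backend drives MoE expert parallelism."""
    NONE = "none"
    DEEPEP = "deepep"
    NIXL = "nixl"  # new in this PR

    def is_nixl(self) -> bool:
        # Dispatcher-selection sites branch on this predicate.
        return self == MoeA2ABackend.NIXL

# e.g. parsing the --moe-a2a-backend server argument:
backend = MoeA2ABackend("nixl")
```

Lookup by value (`MoeA2ABackend("nixl")`) is what makes the enum a natural target for a CLI `choices` list.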

@gemini-code-assist left a comment

Code Review

This PR introduces support for the NIXL-EP MoE backend, which is a significant feature addition. The changes are well-structured, adding nixl as a new backend option and integrating it across the serving stack, from server arguments to the MoE token dispatcher. The new NixlEPDispatcher implementation seems to correctly follow the patterns of existing dispatchers. The addition of a global TCP store for coordination is also handled cleanly. My review has a couple of minor suggestions for code cleanup, but overall the implementation looks solid.

```python
)

# Create a global TCPStore for coordination (used by NIXL)
_create_global_tcp_store(rank, world_size)
```
Collaborator

Maybe we should call this only when nixl-ep is enabled.

Contributor Author

Makes sense, added an if before it.

```python
from sglang.srt.compilation.compilation_config import register_split_op
from sglang.srt.compilation.piecewise_context_manager import is_in_piecewise_cuda_graph
from sglang.srt.distributed.utils import set_global_tcp_store
from sglang.srt.environ import envs
```
Collaborator

FYI

Comment on lines +175 to +176:

```python
self.num_max_dispatch_tokens_per_rank = get_int_env_var(
    "SGLANG_NIXL_EP_NUM_MAX_DISPATCH_TOKENS_PER_RANK", 128
```
Collaborator

ditto

Contributor Author

done

```python
hidden_states: torch.Tensor,
topk_idx: torch.Tensor,
):
    use_fp8 = not get_bool_env_var("SGLANG_NIXL_EP_BF16_DISPATCH")
```
Collaborator

ditto

Contributor Author

done
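Both review threads above concern environment-variable knobs (`SGLANG_NIXL_EP_NUM_MAX_DISPATCH_TOKENS_PER_RANK` and `SGLANG_NIXL_EP_BF16_DISPATCH`). A minimal sketch of how such helpers typically behave — simplified stand-ins, not SGLang's actual `get_int_env_var`/`get_bool_env_var` implementations:

```python
import os

def get_int_env_var(name: str, default: int) -> int:
    """Read an integer knob from the environment, falling back to a default."""
    value = os.environ.get(name)
    return int(value) if value else default

def get_bool_env_var(name: str) -> bool:
    """Treat '1'/'true'/'yes' (case-insensitive) as enabled."""
    return os.environ.get(name, "").lower() in ("1", "true", "yes")

# Mirrors the two uses flagged in review:
num_max_dispatch_tokens_per_rank = get_int_env_var(
    "SGLANG_NIXL_EP_NUM_MAX_DISPATCH_TOKENS_PER_RANK", 128
)
use_fp8 = not get_bool_env_var("SGLANG_NIXL_EP_BF16_DISPATCH")
```

The "ditto"/"done" exchange suggests these were later moved to SGLang's central `envs` registry rather than read ad hoc.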

@ShangmingCai left a comment

Please fix the comments. Others LGTM. Please ping @ch-wan for review and another approval. He is the core maintainer of the EP module.

BBiber1 and others added 5 commits on February 25, 2026
Signed-off-by: Barak Biber <bbiber@nvidia.com>
Signed-off-by: Yoray Zack <yorayz@nvidia.com>
Signed-off-by: Itay Alroy <ialroy@nvidia.com>
Signed-off-by: Yoray Zack <yorayz@nvidia.com>
@zackyoray
Contributor Author

> Please fix the comments. Others LGTM. Please ping @ch-wan for review and another approval. He is the core maintainer of the EP module.

Thanks @ShangmingCai, fixed the comments and rebased.

@ch-wan
Collaborator

ch-wan commented Feb 25, 2026

qq: is this related to elastic ep as indicated in the title?

@zackyoray
Contributor Author

zackyoray commented Feb 25, 2026

> qq: is this related to elastic ep as indicated in the title?

Yes, this PR integrates NIXL-EP, an elastic EP communication library designed with elasticity as its core feature. It natively supports fault tolerance (with rank recovery) and dynamic scale up/down.

@ch-wan left a comment

can we add some ci (or manual) tests for nixl a2a and elastic ep?

zackyoray requested a review from hnyls2002 as a code owner on March 1, 2026
@zackyoray
Contributor Author

> can we add some ci (or manual) tests for nixl a2a and elastic ep?

@ch-wan, first of all, thanks for your review.

I added manual tests in `test/manual/ep/test_nixl_ep.py` covering:

  • NIXL a2a MoE backend (pure TP, DP attention, elastic EP + EPLB)
  • NIXL a2a MoE with Mooncake elastic EP backend for E2E fault tolerance testing

@ShangmingCai
Copy link
Collaborator

/tag-and-rerun-ci

github-actions bot added the run-ci label on Mar 6, 2026
@ShangmingCai
Copy link
Collaborator

Need to fix lint with `pre-commit run --all-files`

@zackyoray
Contributor Author

> Need to fix lint with `pre-commit run --all-files`

Thanks, fixed that.

@ShangmingCai
Collaborator

/rerun-failed-ci

@zackyoray
Contributor Author

Thanks @ShangmingCai. It seems the failing CI is unrelated, am I correct?

@ch-wan, @ShangmingCai, what should be the next step for merging this PR? Should I expect another round of review?

@ShangmingCai left a comment

LGTM. You can ping @ch-wan for a final check.

@ch-wan left a comment

LGTM. Please fix the conflict.

ShangmingCai merged commit 9991deb into sgl-project:main on Mar 11, 2026
28 of 45 checks passed
liubiyongge pushed a commit to liubiyongge/sglang that referenced this pull request Mar 13, 2026
Signed-off-by: Barak Biber <bbiber@nvidia.com>
Signed-off-by: Yoray Zack <yorayz@nvidia.com>
Signed-off-by: Itay Alroy <ialroy@nvidia.com>
Co-authored-by: Barak Biber <bbiber@nvidia.com>
Wangzheee pushed a commit to Wangzheee/sglang that referenced this pull request Mar 21, 2026
Signed-off-by: Barak Biber <bbiber@nvidia.com>
Signed-off-by: Yoray Zack <yorayz@nvidia.com>
Signed-off-by: Itay Alroy <ialroy@nvidia.com>
Co-authored-by: Barak Biber <bbiber@nvidia.com>

Labels

deepseek, documentation, run-ci


4 participants