
[AllReduce] FlashInfer: add mnnvl backend selection and standalone TP path#19586

Draft
mmangkad wants to merge 4 commits into sgl-project:main from mmangkad-dev:fi-ar-mnnvl

Conversation

@mmangkad
Contributor

Motivation

This PR adds support for backend selection (auto/mnnvl/trtllm) for FlashInfer allreduce and introduces standalone FlashInfer allreduce for non-fused TP allreduce. mnnvl is optimized for multi-node NVLink setups and can be more beneficial in that setting.

Modifications

  • Added standalone FlashInfer allreduce for non-fused TP allreduce (kAllReduce path).
  • Added --enable-flashinfer-allreduce (default: False) to enable standalone FlashInfer allreduce.
  • Added --flashinfer-allreduce-backend {auto,trtllm,mnnvl} (default: auto; with current FlashInfer, auto resolves to mnnvl) for both fused and standalone FlashInfer allreduce.
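As a rough illustration of the backend selection described above (this is a hypothetical sketch, not the PR's actual code; the constant name echoes the one mentioned for `server_args.py`, but the helper function is invented for illustration), resolving `auto` could look like:

```python
# Hypothetical sketch of "auto" backend resolution; the PR's real logic may differ.
FLASHINFER_ALLREDUCE_BACKEND_CHOICES = ("auto", "trtllm", "mnnvl")


def resolve_flashinfer_allreduce_backend(backend: str) -> str:
    """Map the user-facing choice to a concrete FlashInfer backend."""
    if backend not in FLASHINFER_ALLREDUCE_BACKEND_CHOICES:
        raise ValueError(f"unknown backend: {backend!r}")
    if backend == "auto":
        # Per the PR description, with current FlashInfer "auto" resolves to mnnvl.
        return "mnnvl"
    return backend
```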

Case Matrix (H100, tp=2)

| Case | Server command |
| --- | --- |
| mnnvl_fusion_custom_ar | `sglang serve --model-path Qwen/Qwen3-30B-A3B-Instruct-2507-FP8 --tensor-parallel-size 2 --enable-flashinfer-allreduce-fusion --flashinfer-allreduce-backend mnnvl` |
| trtllm_fusion_custom_ar | `sglang serve --model-path Qwen/Qwen3-30B-A3B-Instruct-2507-FP8 --tensor-parallel-size 2 --enable-flashinfer-allreduce-fusion --flashinfer-allreduce-backend trtllm` |
| mnnvl_fusion_fi_ar_only | `sglang serve --model-path Qwen/Qwen3-30B-A3B-Instruct-2507-FP8 --tensor-parallel-size 2 --enable-flashinfer-allreduce-fusion --enable-flashinfer-allreduce --disable-custom-all-reduce --flashinfer-allreduce-backend mnnvl` |
| trtllm_fusion_fi_ar_only | `sglang serve --model-path Qwen/Qwen3-30B-A3B-Instruct-2507-FP8 --tensor-parallel-size 2 --enable-flashinfer-allreduce-fusion --enable-flashinfer-allreduce --disable-custom-all-reduce --flashinfer-allreduce-backend trtllm` |

Accuracy Tests (Full GSM8K, 8-shot)

| Case | Accuracy | Invalid |
| --- | --- | --- |
| mnnvl_fusion_custom_ar | 0.941 | 0.000 |
| trtllm_fusion_custom_ar | 0.942 | 0.000 |
| mnnvl_fusion_fi_ar_only | 0.941 | 0.000 |
| trtllm_fusion_fi_ar_only | 0.942 | 0.000 |

No accuracy regression observed.

Benchmarking

| Case | Output tok/s (bs=1) | Mean TPOT (bs=1) | Output tok/s (bs=16) | Mean TPOT (bs=16) |
| --- | --- | --- | --- | --- |
| mnnvl_fusion_custom_ar | 237.62 | 4.09 | 1666.91 | 8.64 |
| trtllm_fusion_custom_ar | 237.18 | 4.10 | 1660.79 | 8.71 |
| mnnvl_fusion_fi_ar_only | 237.83 | 4.09 | 1689.28 | 8.57 |
| trtllm_fusion_fi_ar_only | 236.13 | 4.12 | 1686.32 | 8.59 |

No TPOT regression observed. mnnvl is on par with trtllm and, in some configurations, slightly better.

@github-actions github-actions bot added the documentation Improvements or additions to documentation label Feb 28, 2026
@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the distributed communication capabilities by integrating advanced FlashInfer allreduce features. It provides users with greater control over the allreduce backend, including a new mnnvl option optimized for multi-node NVLink, and introduces a standalone FlashInfer allreduce path for non-fused tensor parallelism. These additions aim to improve performance and adaptability across various hardware environments without compromising accuracy or throughput.

Highlights

  • FlashInfer AllReduce Backend Selection: Introduced the ability to select between 'auto', 'trtllm', and 'mnnvl' backends for FlashInfer allreduce operations, allowing for optimized performance in different hardware configurations, especially multi-node NVLink setups with 'mnnvl'.
  • Standalone FlashInfer AllReduce: Added support for a standalone FlashInfer allreduce path for non-fused Tensor Parallel (TP) allreduce, providing more flexibility and potential performance gains beyond fused operations.
  • New Configuration Options: New command-line arguments --enable-flashinfer-allreduce and --flashinfer-allreduce-backend were added to enable and configure the standalone FlashInfer allreduce functionality.


Changelog
  • docs/advanced_features/server_arguments.md
    • Added documentation for the new --enable-flashinfer-allreduce and --flashinfer-allreduce-backend server arguments.
  • python/sglang/srt/distributed/device_communicators/flashinfer_all_reduce.py
    • Added a new FlashInferAllReduce class to manage standalone FlashInfer allreduce operations, including workspace initialization, buffer size checks, and allreduce execution.
    • Implemented logic to disable FlashInfer allreduce if FlashInfer communication API is unavailable or CUDA is not present.
    • Defined GPU capability-based maximum tensor sizes for allreduce operations, adopted from vLLM thresholds.
  • python/sglang/srt/distributed/device_communicators/flashinfer_utils.py
    • Added a new file to provide utility functions for FlashInfer, specifically create_mnnvl_comm_backend to bridge FlashInfer's mnnvl backend with torch.distributed process groups.
    • Included a placeholder CommBackend class for when FlashInfer is not available.
  • python/sglang/srt/distributed/parallel_state.py
    • Added use_flashinfer_allreduce and flashinfer_allreduce_backend attributes to the GroupCoordinator class.
    • Introduced fi_ar_comm (FlashInfer allreduce communicator) to GroupCoordinator.
    • Modified the GroupCoordinator constructor to accept and initialize FlashInfer allreduce parameters.
    • Updated the all_reduce method to prioritize FlashInfer allreduce if enabled and applicable.
    • Adjusted the _all_reduce_out_place method to dispatch to the FlashInfer allreduce communicator.
    • Ensured proper destruction of the FlashInfer allreduce communicator in the destroy method.
    • Updated init_world_group and init_model_parallel_group functions to support the new FlashInfer allreduce configuration parameters.
    • Added global flags _ENABLE_FLASHINFER_ALL_REDUCE and _FLASHINFER_ALLREDUCE_BACKEND along with their setter functions.
  • python/sglang/srt/layers/flashinfer_comm_fusion.py
    • Imported get_tp_group and create_mnnvl_comm_backend for enhanced distributed communication.
    • Added helper functions _get_flashinfer_allreduce_backend and _create_mnnvl_comm_backend to retrieve backend settings and create mnnvl communicators.
    • Modified FlashInferWorkspaceManager to incorporate backend selection (auto/mnnvl/trtllm) during workspace initialization.
    • Improved workspace management logic, including checks for mnnvl backend initialization and more robust buffer size sufficiency checks to prevent repeated reinitialization.
    • Updated ensure_workspace_initialized to pass the selected backend to the workspace manager.
  • python/sglang/srt/model_executor/model_runner.py
    • Imported new setter functions for FlashInfer allreduce configuration.
    • Updated the init_torch_distributed method to configure FlashInfer allreduce based on server arguments, enabling or disabling it and setting the chosen backend.
  • python/sglang/srt/server_args.py
    • Defined FLASHINFER_ALLREDUCE_BACKEND_CHOICES for valid backend options.
    • Added enable_flashinfer_allreduce (boolean) and flashinfer_allreduce_backend (literal type) to the ServerArgs class.
    • Included corresponding command-line arguments in the add_cli_args function for user configuration.
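Taken together, the `parallel_state.py` changes in the changelog amount to a dispatch priority inside `all_reduce`. The following is a minimal sketch under the changelog's naming (`fi_ar_comm`, `use_flashinfer_allreduce`); the real method operates on torch tensors and process groups and handles more fallbacks than shown here:

```python
# Hypothetical sketch of the dispatch priority described for GroupCoordinator.all_reduce;
# the actual method in parallel_state.py has more cases (out-of-place paths, graph capture, etc.).
class GroupCoordinator:
    def __init__(self, use_flashinfer_allreduce=False, fi_ar_comm=None,
                 custom_ar_comm=None, fallback=None):
        self.use_flashinfer_allreduce = use_flashinfer_allreduce
        self.fi_ar_comm = fi_ar_comm          # FlashInfer allreduce communicator
        self.custom_ar_comm = custom_ar_comm  # existing custom allreduce communicator
        self.fallback = fallback              # e.g. torch.distributed.all_reduce

    def all_reduce(self, t):
        # 1. Prefer FlashInfer allreduce when enabled and the input qualifies
        #    (e.g. within the capability-based maximum tensor size).
        if (self.use_flashinfer_allreduce and self.fi_ar_comm is not None
                and self.fi_ar_comm.should_use(t)):
            return self.fi_ar_comm.all_reduce(t)
        # 2. Otherwise fall back to the existing custom allreduce.
        if self.custom_ar_comm is not None and self.custom_ar_comm.should_use(t):
            return self.custom_ar_comm.all_reduce(t)
        # 3. Last resort: the plain collective on the device group.
        return self.fallback(t)
```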
Activity
  • The pull request author, mmangkad, initiated this change to introduce FlashInfer allreduce backend selection and a standalone TP path.
  • The author provided detailed motivation, modifications, case matrix, accuracy tests, and benchmarking results in the pull request description.
  • Accuracy tests showed no regression, and benchmarking indicated that the mnnvl backend is on par with or potentially better than trtllm.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces support for the mnnvl backend for FlashInfer all-reduce operations and adds a standalone path for non-fused tensor parallelism all-reduce. The changes are well-structured, adding new arguments, documentation, and the necessary logic in the communication and parallel state management modules. My review focuses on improving error handling, ensuring the correctness of a fallback implementation, and enhancing user-facing documentation for the new options. Overall, this is a solid contribution to extend the all-reduce capabilities.

@mmangkad
Contributor Author

/tag-and-rerun-ci

@yuan-luo
Collaborator

yuan-luo commented Mar 1, 2026

MNNVL only supports SM100, which is a hard limitation for a general solution.

@mmangkad
Contributor Author

mmangkad commented Mar 1, 2026

@yuan-luo
(screenshot attached)

@yuan-luo
Collaborator

yuan-luo commented Mar 1, 2026

Maybe FlashInfer was upgraded recently; as of last November it didn't support Hopper.
#12787
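The disagreement above is about a compute-capability gate. As a hypothetical illustration only (neither the function name nor the threshold is from this PR, and the thread itself disputes where the cutoff lies), such a gate might look like:

```python
# Hypothetical capability gate; name and threshold are illustrative, not from the PR.
# The author's H100 results above suggest current FlashInfer mnnvl runs on Hopper
# (SM90) as well as Blackwell (SM100); older FlashInfer may have required SM100.
def mnnvl_capability_ok(compute_capability: tuple) -> bool:
    major, _minor = compute_capability
    return major >= 9  # SM90 (Hopper) and newer


# In a CUDA process, this would typically be fed torch.cuda.get_device_capability().
```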


Labels

documentation Improvements or additions to documentation run-ci
