diff --git a/docs/advanced_features/server_arguments.md b/docs/advanced_features/server_arguments.md index 961d45505361..1df5a06f3aea 100644 --- a/docs/advanced_features/server_arguments.md +++ b/docs/advanced_features/server_arguments.md @@ -314,7 +314,9 @@ Please consult the documentation below and [server_args.py](https://github.com/s | `--moe-a2a-backend` | Select the backend for all-to-all communication for expert parallelism. | `none` | `none`, `deepep`, `mooncake`, `mori`, `nixl`, `ascend_fuseep`| | `--moe-runner-backend` | Choose the runner backend for MoE. | `auto` | `auto`, `deep_gemm`, `triton`, `triton_kernel`, `flashinfer_trtllm`, `flashinfer_trtllm_routed`, `flashinfer_cutlass`, `flashinfer_mxfp4`, `flashinfer_cutedsl`, `cutlass` | | `--flashinfer-mxfp4-moe-precision` | Choose the computation precision of flashinfer mxfp4 moe | `default` | `default`, `bf16` | +| `--enable-flashinfer-allreduce` | Enable FlashInfer standalone allreduce for non-fused TP allreduce. | `False` | bool flag (set to enable) | | `--enable-flashinfer-allreduce-fusion` | Enable FlashInfer allreduce fusion with Residual RMSNorm. | `False` | bool flag (set to enable) | +| `--flashinfer-allreduce-backend` | Select FlashInfer backend for standalone/fused allreduce. | `auto` | `auto`, `trtllm`, `mnnvl` | | `--enable-aiter-allreduce-fusion` | Enable aiter allreduce fusion with Residual RMSNorm. | `False` | bool flag (set to enable) | | `--deepep-mode` | Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch. | `auto` | `normal`, `low_latency`, `auto` | | `--ep-num-redundant-experts` | Allocate this number of redundant experts in expert parallel. | `0` | Type: int | diff --git a/python/sglang/srt/distributed/device_communicators/flashinfer_all_reduce.py b/python/sglang/srt/distributed/device_communicators/flashinfer_all_reduce.py new file mode 100644 index 000000000000..a727f4fb89b7 --- /dev/null +++ b/python/sglang/srt/distributed/device_communicators/flashinfer_all_reduce.py @@ -0,0 +1,206 @@ +import logging +from typing import Optional + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup + +from sglang.srt.distributed.device_communicators.flashinfer_utils import ( + create_mnnvl_comm_backend, +) + +logger = logging.getLogger(__name__) + +_flashinfer_comm = None +_flashinfer_ar_available = False +try: + import flashinfer.comm as flashinfer_comm + + if hasattr(flashinfer_comm, "allreduce_fusion") and hasattr( + flashinfer_comm, "create_allreduce_fusion_workspace" + ): + _flashinfer_comm = flashinfer_comm + _flashinfer_ar_available = True +except ImportError: + pass + +MiB = 1024 * 1024 + +# Max size of the communicated tensor by world size and GPU capability. +# Adopted from vLLM thresholds. +# TODO(mmangkad): Tune these thresholds for SGLang, since optimal values may +# differ from vLLM based on runtime/scheduling behavior. +_FI_ALLREDUCE_MAX_SIZE_MB: dict[int, dict[int, float]] = { + 90: { + 2: 64, + 4: 2, + 8: 0.5, + }, + 100: { + 2: 64, + 4: 32, + 8: 1, + }, +} + + +def _get_device_capability() -> Optional[int]: + if not torch.cuda.is_available(): + return None + major, minor = torch.cuda.get_device_capability() + return major * 10 + minor + + +class FlashInferAllReduce: + def __init__( + self, + group: ProcessGroup, + device: torch.device, + backend: str = "auto", + ): + self.disabled = True + self.workspace = None + self.max_num_tokens = 0 + self.max_workspace_size = None + self.hidden_dim = None + self.dtype = None + + if not _flashinfer_ar_available or _flashinfer_comm is None: + logger.info( + "FlashInfer allreduce disabled: flashinfer comm API unavailable." + ) + return + + if not torch.cuda.is_available(): + logger.info("FlashInfer allreduce disabled: CUDA is unavailable.") + return + + self.group = group + self.world_size = dist.get_world_size(group=self.group) + self.rank = dist.get_rank(group=self.group) + self.device = device + self.backend = backend + + if self.world_size == 1: + return + + capability = _get_device_capability() + self.max_workspace_size = _FI_ALLREDUCE_MAX_SIZE_MB.get(capability, {}).get( + self.world_size + ) + if self.max_workspace_size is None: + logger.warning( + "FlashInfer allreduce disabled: unsupported world_size=%d for SM=%s.", + self.world_size, + str(capability), + ) + return + + self.max_workspace_size = int(self.max_workspace_size * MiB) + self.disabled = False + + def _create_workspace( + self, + max_token_num: int, + hidden_dim: int, + dtype: torch.dtype, + ) -> bool: + assert _flashinfer_comm is not None + + workspace_kwargs = dict( + backend=self.backend, + world_size=self.world_size, + rank=self.rank, + max_token_num=max_token_num, + hidden_dim=hidden_dim, + dtype=dtype, + ) + + if self.backend in ("auto", "mnnvl"): + comm_backend = create_mnnvl_comm_backend(self.group) + if comm_backend is not None: + workspace_kwargs["comm_backend"] = comm_backend + + self.workspace = _flashinfer_comm.create_allreduce_fusion_workspace( + **workspace_kwargs + ) + self.hidden_dim = hidden_dim + self.dtype = dtype + return self.workspace is not None + + def _ensure_workspace( + self, + num_tokens: int, + hidden_dim: int, + dtype: torch.dtype, + ) -> bool: + if self.workspace is not None: + if self.hidden_dim == hidden_dim and self.dtype == dtype: + try: + if self.workspace.is_buffer_size_sufficient( + tp_size=self.world_size, + num_tokens=num_tokens, + hidden_dim=hidden_dim, + dtype=dtype, + ): + return True + except Exception as e: + logger.debug( + "FlashInfer workspace size check failed; recreating workspace: %s", + e, + ) + self.destroy() + + assert self.max_workspace_size is not None + element_size = torch.tensor([], dtype=dtype, device="cpu").element_size() + max_tokens = self.max_workspace_size // (hidden_dim * element_size) + if max_tokens <= 0 or num_tokens > max_tokens: + return False + + self.max_num_tokens = max_tokens + try: + return self._create_workspace( + max_token_num=max_tokens, + hidden_dim=hidden_dim, + dtype=dtype, + ) + except Exception as e: + logger.warning( + "Failed to initialize FlashInfer allreduce workspace: %s. " + "Disabling FlashInfer allreduce.", + e, + ) + self.disabled = True + self.workspace = None + return False + + def should_use_fi_ar(self, input_tensor: torch.Tensor) -> bool: + if self.disabled: + return False + + if not input_tensor.is_cuda or not input_tensor.is_contiguous(): + return False + + if len(input_tensor.shape) != 2: + return False + + num_tokens, hidden_dim = input_tensor.shape + return self._ensure_workspace(num_tokens, hidden_dim, input_tensor.dtype) + + def all_reduce(self, input_tensor: torch.Tensor) -> torch.Tensor: + assert _flashinfer_comm is not None + return _flashinfer_comm.allreduce_fusion( + input=input_tensor, + workspace=self.workspace, + pattern=_flashinfer_comm.AllReduceFusionPattern.kAllReduce, + ) + + def destroy(self): + if self.workspace is not None: + try: + self.workspace.destroy() + except Exception as e: + logger.debug("Failed to destroy FlashInfer workspace: %s", e) + self.workspace = None + self.hidden_dim = None + self.dtype = None diff --git a/python/sglang/srt/distributed/device_communicators/flashinfer_utils.py b/python/sglang/srt/distributed/device_communicators/flashinfer_utils.py new file mode 100644 index 000000000000..07b4fb3ffc63 --- /dev/null +++ b/python/sglang/srt/distributed/device_communicators/flashinfer_utils.py @@ -0,0 +1,53 @@ +import torch.distributed as dist + +from sglang.srt.utils import is_flashinfer_available + +if is_flashinfer_available(): + try: + from flashinfer.comm.mnnvl import CommBackend + except ImportError: + CommBackend = object # type: ignore[assignment,misc] +else: + + class CommBackend: + """Placeholder base class when flashinfer is not available.""" + + pass + + +def create_mnnvl_comm_backend(group: dist.ProcessGroup): + """Create a mnnvl comm backend backed by torch.distributed process group.""" + try: + from flashinfer.comm.mnnvl import TorchDistBackend + + return TorchDistBackend(group=group) + except Exception: + pass + + class TorchDistributedCommBackend(CommBackend): + def __init__(self, group_: dist.ProcessGroup): + self._group = group_ + + def Get_rank(self) -> int: + return self._group.rank() + + def Get_size(self) -> int: + return self._group.size() + + def allgather(self, data: int): + gathered = [None] * self.Get_size() + dist.all_gather_object(gathered, data, group=self._group) + return gathered + + def bcast(self, data, root: int = 0): + obj_list = [data] + dist.broadcast_object_list(obj_list, src=root, group=self._group) + return obj_list[0] + + def Split(self, color: int, key: int): + return self + + def barrier(self): + dist.barrier(group=self._group) + + return TorchDistributedCommBackend(group) diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py index 72311f9d3ffe..4676693684b8 100644 --- a/python/sglang/srt/distributed/parallel_state.py +++ b/python/sglang/srt/distributed/parallel_state.py @@ -220,6 +220,7 @@ class GroupCoordinator: use_pynccl: bool # a hint of whether to use PyNccl use_pymscclpp: bool # a hint of whether to use PyMsccl use_custom_allreduce: bool # a hint of whether to use CustomAllreduce + use_flashinfer_allreduce: bool # a hint of whether to use FlashInfer allreduce use_torch_symm_mem_all_reduce: ( bool # a hint of whether to use TorchSymmMemAllReduce ) @@ -229,6 +230,7 @@ class GroupCoordinator: # communicators are only created for world size > 1 pynccl_comm: Optional[Any] # PyNccl communicator ca_comm: Optional[Any] # Custom allreduce communicator + fi_ar_comm: Optional[Any] # FlashInfer allreduce communicator torch_symm_mem_comm: Optional[Any] # Torch symm mem communicator mq_broadcaster: Optional[Any] # shared memory broadcaster @@ -240,6 +242,8 @@ def __init__( use_pynccl: bool, use_pymscclpp: bool, use_custom_allreduce: bool, + use_flashinfer_allreduce: bool, + flashinfer_allreduce_backend: str, use_torch_symm_mem_all_reduce: bool, use_hpu_communicator: bool, use_xpu_communicator: bool, @@ -319,6 +323,8 @@ def __init__( self.pynccl_use_current_stream = pynccl_use_current_stream self.use_pymscclpp = use_pymscclpp self.use_custom_allreduce = use_custom_allreduce + self.use_flashinfer_allreduce = use_flashinfer_allreduce + self.flashinfer_allreduce_backend = flashinfer_allreduce_backend self.use_torch_symm_mem_all_reduce = use_torch_symm_mem_all_reduce self.use_hpu_communicator = use_hpu_communicator self.use_xpu_communicator = use_xpu_communicator @@ -329,6 +335,9 @@ def __init__( from sglang.srt.distributed.device_communicators.custom_all_reduce import ( dispatch_custom_allreduce, ) + from sglang.srt.distributed.device_communicators.flashinfer_all_reduce import ( + FlashInferAllReduce, + ) from sglang.srt.distributed.device_communicators.pymscclpp import ( PyMscclppCommunicator, ) @@ -368,6 +377,19 @@ def __init__( device=self.device, ) + self.fi_ar_comm: Optional[FlashInferAllReduce] = None + # Standalone FlashInfer allreduce is currently only intended for TP-like groups. + if ( + use_flashinfer_allreduce + and self.world_size > 1 + and "tp" in self.unique_name + ): + self.fi_ar_comm = FlashInferAllReduce( + group=self.cpu_group, + device=self.device, + backend=flashinfer_allreduce_backend, + ) + self.ca_comm: Optional[Any] = None self.qr_comm: Optional[QuickAllReduce] = None if use_custom_allreduce and self.world_size > 1: @@ -610,17 +632,23 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: outplace_all_reduce_method = None if ( - self.ca_comm is not None - and not self.ca_comm.disabled - and self.ca_comm.should_custom_ar(input_) - ): - outplace_all_reduce_method = "ca" - elif ( self.qr_comm is not None and not self.qr_comm.disabled and self.qr_comm.should_quick_allreduce(input_) ): outplace_all_reduce_method = "qr" + elif ( + self.fi_ar_comm is not None + and not self.fi_ar_comm.disabled + and self.fi_ar_comm.should_use_fi_ar(input_) + ): + outplace_all_reduce_method = "flashinfer" + elif ( + self.ca_comm is not None + and not self.ca_comm.disabled + and self.ca_comm.should_custom_ar(input_) + ): + outplace_all_reduce_method = "ca" elif ( self.pymscclpp_comm is not None and not self.pymscclpp_comm.disabled @@ -703,16 +731,29 @@ def _all_reduce_out_place( ) -> torch.Tensor: ca_comm = self.ca_comm qr_comm = self.qr_comm + fi_ar_comm = self.fi_ar_comm pymscclpp_comm = self.pymscclpp_comm torch_symm_mem_comm = self.torch_symm_mem_comm pynccl_comm = self.pynccl_comm - assert any([qr_comm, ca_comm, pymscclpp_comm, torch_symm_mem_comm, pynccl_comm]) + assert any( + [ + qr_comm, + fi_ar_comm, + ca_comm, + pymscclpp_comm, + torch_symm_mem_comm, + pynccl_comm, + ] + ) if outplace_all_reduce_method == "ca": assert not ca_comm.disabled out = ca_comm.custom_all_reduce(input_) elif outplace_all_reduce_method == "qr": assert not qr_comm.disabled out = qr_comm.quick_all_reduce(input_) + elif outplace_all_reduce_method == "flashinfer": + assert not fi_ar_comm.disabled + out = fi_ar_comm.all_reduce(input_) elif outplace_all_reduce_method == "torch_symm_mem": assert not torch_symm_mem_comm.disabled out = torch_symm_mem_comm.all_reduce(input_) @@ -1398,6 +1439,9 @@ def destroy(self): self.cpu_group = None if self.pynccl_comm is not None: self.pynccl_comm = None + if self.fi_ar_comm is not None: + self.fi_ar_comm.destroy() + self.fi_ar_comm = None if self.ca_comm is not None: self.ca_comm = None if self.mq_broadcaster is not None: @@ -1422,6 +1466,8 @@ def init_world_group( use_pynccl=False, use_pymscclpp=False, use_custom_allreduce=False, + use_flashinfer_allreduce=False, + flashinfer_allreduce_backend="auto", use_torch_symm_mem_all_reduce=False, use_hpu_communicator=False, use_xpu_communicator=False, @@ -1441,6 +1487,8 @@ def init_model_parallel_group( use_mscclpp_allreduce: Optional[bool] = None, pynccl_use_current_stream: bool = True, use_torch_symm_mem_allreduce: Optional[bool] = None, + use_flashinfer_allreduce: Optional[bool] = None, + flashinfer_allreduce_backend: Optional[str] = None, ) -> GroupCoordinator: if use_custom_allreduce is None: use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE @@ -1448,6 +1496,10 @@ def init_model_parallel_group( use_mscclpp_allreduce = _ENABLE_MSCCLPP_ALL_REDUCE if use_torch_symm_mem_allreduce is None: use_torch_symm_mem_allreduce = _ENABLE_TORCH_SYMM_MEM_ALL_REDUCE + if use_flashinfer_allreduce is None: + use_flashinfer_allreduce = _ENABLE_FLASHINFER_ALL_REDUCE + if flashinfer_allreduce_backend is None: + flashinfer_allreduce_backend = _FLASHINFER_ALLREDUCE_BACKEND return GroupCoordinator( group_ranks=group_ranks, local_rank=local_rank, @@ -1459,6 +1511,8 @@ def init_model_parallel_group( ), use_pymscclpp=use_mscclpp_allreduce, use_custom_allreduce=use_custom_allreduce, + use_flashinfer_allreduce=use_flashinfer_allreduce, + flashinfer_allreduce_backend=flashinfer_allreduce_backend, use_torch_symm_mem_all_reduce=use_torch_symm_mem_allreduce, use_hpu_communicator=True, use_xpu_communicator=True, @@ -1586,6 +1640,8 @@ def graph_capture(stream: Optional[torch.cuda.Stream] = None): _ENABLE_CUSTOM_ALL_REDUCE = True _ENABLE_MSCCLPP_ALL_REDUCE = False _ENABLE_TORCH_SYMM_MEM_ALL_REDUCE = False +_ENABLE_FLASHINFER_ALL_REDUCE = False +_FLASHINFER_ALLREDUCE_BACKEND = "auto" def set_custom_all_reduce(enable: bool): @@ -1603,6 +1659,16 @@ def set_torch_symm_mem_all_reduce(enable: bool): _ENABLE_TORCH_SYMM_MEM_ALL_REDUCE = enable +def set_flashinfer_all_reduce(enable: bool): + global _ENABLE_FLASHINFER_ALL_REDUCE + _ENABLE_FLASHINFER_ALL_REDUCE = enable + + +def set_flashinfer_all_reduce_backend(backend: str): + global _FLASHINFER_ALLREDUCE_BACKEND + _FLASHINFER_ALLREDUCE_BACKEND = backend + + _DEVICE_TO_DISTRIBUTED_BACKEND = { "cuda": "nccl", "xpu": "xccl", diff --git a/python/sglang/srt/layers/flashinfer_comm_fusion.py b/python/sglang/srt/layers/flashinfer_comm_fusion.py index 04a6ababc56a..5e0069b3bc4b 100644 --- a/python/sglang/srt/layers/flashinfer_comm_fusion.py +++ b/python/sglang/srt/layers/flashinfer_comm_fusion.py @@ -10,6 +10,10 @@ get_moe_expert_parallel_world_size, get_moe_tensor_parallel_rank, get_moe_tensor_parallel_world_size, + get_tp_group, +) +from sglang.srt.distributed.device_communicators.flashinfer_utils import ( + create_mnnvl_comm_backend, ) from sglang.srt.utils import is_flashinfer_available from sglang.srt.utils.custom_op import register_custom_op @@ -43,7 +47,29 @@ def is_flashinfer_allreduce_unavailable() -> bool: - return _flashinfer_allreduce_unavailable + return _flashinfer_allreduce_unavailable or ( + _workspace_manager is not None + and getattr(_workspace_manager, "disabled", False) + ) + + +def _get_flashinfer_allreduce_backend() -> str: + try: + from sglang.srt.server_args import get_global_server_args + + return get_global_server_args().flashinfer_allreduce_backend + except Exception: + return "auto" + + +def _create_mnnvl_comm_backend(): + try: + tp_group = get_tp_group().cpu_group + except Exception as e: + logger.debug(f"Failed to fetch TP process group for mnnvl backend: {e}") + return None + + return create_mnnvl_comm_backend(tp_group) class FlashInferWorkspaceManager: @@ -54,10 +80,15 @@ def __init__(self): self.max_token_num = None self.hidden_dim = None self.dtype = None + self.backend = None + self.resolved_backend = None + self.force_oneshot_support = False self.initialized = False + self.disabled = False def initialize( self, + backend: str, world_size: int, rank: int, max_token_num: int, @@ -66,6 +97,9 @@ def initialize( use_oneshot: Optional[bool] = None, ): """Initialize workspace""" + if self.disabled: + return + if _flashinfer_comm is None: logger.warning( "FlashInfer comm not available, skipping workspace initialization" @@ -74,8 +108,8 @@ def initialize( self.cleanup() try: - self.workspace = _flashinfer_comm.create_allreduce_fusion_workspace( - backend="trtllm", + workspace_kwargs = dict( + backend=backend, world_size=world_size, rank=rank, max_token_num=max_token_num, @@ -83,15 +117,29 @@ def initialize( dtype=dtype, force_oneshot_support=bool(use_oneshot), ) + + if backend in ("auto", "mnnvl"): + comm_backend = _create_mnnvl_comm_backend() + if comm_backend is not None: + workspace_kwargs["comm_backend"] = comm_backend + else: + logger.warning( + "Could not initialize mnnvl comm backend from torch.distributed. " + "FlashInfer will fall back to default comm backend behavior." + ) + + self.workspace = _flashinfer_comm.create_allreduce_fusion_workspace( + **workspace_kwargs + ) except Exception as e: - global _flashinfer_allreduce_unavailable - _flashinfer_allreduce_unavailable = True logger.warning( - f"Failed to initialize FlashInfer workspace: {e}. " - "Disabling flashinfer allreduce fusion permanently." + "Failed to initialize FlashInfer workspace: %s. " + "Disabling FlashInfer allreduce fusion for this run.", + e, ) self.workspace = None self.initialized = False + self.disabled = True return self.world_size = world_size @@ -99,12 +147,16 @@ def initialize( self.max_token_num = max_token_num self.hidden_dim = hidden_dim self.dtype = dtype + self.backend = backend + self.resolved_backend = getattr(self.workspace, "backend", backend) + self.force_oneshot_support = bool(use_oneshot) self.initialized = True + self.disabled = False - backend = getattr(self.workspace, "backend", "unknown") logger.info( f"FlashInfer workspace initialized for rank {rank}, " - f"world_size {world_size}, backend {backend}" + f"world_size {world_size}, backend {self.resolved_backend} " + f"(requested: {backend})" ) def is_buffer_size_sufficient( @@ -116,17 +168,32 @@ def is_buffer_size_sufficient( ) -> bool: if not self.initialized or self.workspace is None: return False - try: - return self.workspace.is_buffer_size_sufficient( - tp_size=self.world_size, - num_tokens=token_num, - hidden_dim=hidden_dim, - dtype=dtype, - use_oneshot=use_oneshot, - ) - except Exception as e: - logger.debug(f"FlashInfer workspace size check failed: {e}") + if hidden_dim != self.hidden_dim or dtype != self.dtype: + return False + if token_num > self.max_token_num: return False + if bool(use_oneshot) and not self.force_oneshot_support: + return False + # Avoid expensive/unstable backend checks on mnnvl which can trigger + # repeated workspace churn during graph capture. + if self.resolved_backend == "trtllm" and hasattr( + self.workspace, "is_buffer_size_sufficient" + ): + try: + return self.workspace.is_buffer_size_sufficient( + tp_size=self.world_size, + num_tokens=token_num, + hidden_dim=hidden_dim, + dtype=dtype, + use_oneshot=use_oneshot, + ) + except Exception as e: + logger.debug( + "FlashInfer workspace internal size check failed; keeping current " + "workspace to avoid repeated reinitialization: %s", + e, + ) + return True def cleanup(self): """Clean up workspace""" @@ -143,6 +210,9 @@ def cleanup(self): self.max_token_num = None self.hidden_dim = None self.dtype = None + self.backend = None + self.resolved_backend = None + self.force_oneshot_support = False _workspace_manager = FlashInferWorkspaceManager() @@ -162,6 +232,8 @@ def ensure_workspace_initialized( if not is_flashinfer_available() or _flashinfer_comm is None: return False + if _workspace_manager.disabled: + return False if use_attn_tp_group: world_size = get_attn_tensor_model_parallel_world_size() @@ -181,11 +253,13 @@ def ensure_workspace_initialized( return False token_num = token_num or max_token_num + backend = _get_flashinfer_allreduce_backend() if ( not _workspace_manager.initialized or _workspace_manager.world_size != world_size or _workspace_manager.rank != rank + or _workspace_manager.backend != backend or not _workspace_manager.is_buffer_size_sufficient( token_num=token_num, hidden_dim=hidden_dim, @@ -194,6 +268,7 @@ def ensure_workspace_initialized( ) ): _workspace_manager.initialize( + backend=backend, world_size=world_size, rank=rank, max_token_num=max_token_num, diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index e5ba805aacdc..1ecbb9bec4f6 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -65,6 +65,8 @@ init_distributed_environment, initialize_model_parallel, set_custom_all_reduce, + set_flashinfer_all_reduce, + set_flashinfer_all_reduce_backend, set_mscclpp_all_reduce, set_torch_symm_mem_all_reduce, ) @@ -849,6 +851,8 @@ def init_torch_distributed(self): self.server_args.host or "127.0.0.1", self.dist_port ).to_tcp() set_custom_all_reduce(not self.server_args.disable_custom_all_reduce) + set_flashinfer_all_reduce(self.server_args.enable_flashinfer_allreduce) + set_flashinfer_all_reduce_backend(self.server_args.flashinfer_allreduce_backend) set_mscclpp_all_reduce(self.server_args.enable_mscclpp) set_torch_symm_mem_all_reduce(self.server_args.enable_torch_symm_mem) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index decfc3f0e480..386a2dcc7890 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -201,6 +201,8 @@ "flashinfer", ] +FLASHINFER_ALLREDUCE_BACKEND_CHOICES = ["auto", "trtllm", "mnnvl"] + FP8_GEMM_RUNNER_BACKEND_CHOICES = [ "auto", "deep_gemm", @@ -517,7 +519,9 @@ class ServerArgs: ] = "none" moe_runner_backend: str = "auto" flashinfer_mxfp4_moe_precision: Literal["default", "bf16"] = "default" + enable_flashinfer_allreduce: bool = False enable_flashinfer_allreduce_fusion: bool = False + flashinfer_allreduce_backend: Literal["auto", "trtllm", "mnnvl"] = "auto" enable_aiter_allreduce_fusion: bool = False deepep_mode: Literal["auto", "normal", "low_latency"] = "auto" ep_num_redundant_experts: int = 0 @@ -4813,11 +4817,23 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.flashinfer_mxfp4_moe_precision, help="Choose the computation precision of flashinfer mxfp4 moe", ) + parser.add_argument( + "--enable-flashinfer-allreduce", + action="store_true", + help="Enable FlashInfer standalone allreduce for non-fused TP allreduce.", + ) parser.add_argument( "--enable-flashinfer-allreduce-fusion", action="store_true", help="Enable FlashInfer allreduce fusion with Residual RMSNorm.", ) + parser.add_argument( + "--flashinfer-allreduce-backend", + type=str, + choices=FLASHINFER_ALLREDUCE_BACKEND_CHOICES, + default=ServerArgs.flashinfer_allreduce_backend, + help="FlashInfer backend for standalone/fused allreduce. `mnnvl` is optimized for multi-node NVLink setups.", + ) parser.add_argument( "--enable-aiter-allreduce-fusion", action="store_true",