diff --git a/docs/advanced_features/server_arguments.md b/docs/advanced_features/server_arguments.md index 730a2274513a..98de77dce7a8 100644 --- a/docs/advanced_features/server_arguments.md +++ b/docs/advanced_features/server_arguments.md @@ -314,7 +314,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s | `--moe-a2a-backend` | Select the backend for all-to-all communication for expert parallelism. | `none` | `none`, `deepep`, `mooncake`, `mori`, `nixl`, `ascend_fuseep`| | `--moe-runner-backend` | Choose the runner backend for MoE. | `auto` | `auto`, `deep_gemm`, `triton`, `triton_kernel`, `flashinfer_trtllm`, `flashinfer_trtllm_routed`, `flashinfer_cutlass`, `flashinfer_mxfp4`, `flashinfer_cutedsl`, `cutlass` | | `--flashinfer-mxfp4-moe-precision` | Choose the computation precision of flashinfer mxfp4 moe | `default` | `default`, `bf16` | -| `--enable-flashinfer-allreduce-fusion` | Enable FlashInfer allreduce fusion with Residual RMSNorm. | `False` | bool flag (set to enable) | +| `--flashinfer-allreduce-fusion-backend` | Enable FlashInfer allreduce fusion (fused allreduce + Residual + RMSNorm) and choose the backend. When not set, the feature is disabled unless SGLang auto-enables it (as `auto`) for a supported model architecture and hardware configuration. Options: `auto` (choose best), `trtllm` (SM90/100, single-node only), `mnnvl` (SM100, single/multi-node). The backend support table (SM100/SM90, single/multi-node) is in `sglang.srt.layers.flashinfer_comm_fusion`. | `None` | `auto`, `trtllm`, `mnnvl` | | `--enable-aiter-allreduce-fusion` | Enable aiter allreduce fusion with Residual RMSNorm. | `False` | bool flag (set to enable) | | `--deepep-mode` | Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch. | `auto` | `normal`, `low_latency`, `auto` | | `--ep-num-redundant-experts` | Allocate this number of redundant experts in expert parallel. 
| `0` | Type: int | @@ -563,6 +563,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s | `--enable-flashinfer-trtllm-moe` | NOTE: --enable-flashinfer-trtllm-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_trtllm' instead. | `None` | N/A | | `--enable-triton-kernel-moe` | NOTE: --enable-triton-kernel-moe is deprecated. Please set `--moe-runner-backend` to 'triton_kernel' instead. | `None` | N/A | | `--enable-flashinfer-mxfp4-moe` | NOTE: --enable-flashinfer-mxfp4-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_mxfp4' instead. | `None` | N/A | +| `--enable-flashinfer-allreduce-fusion` | NOTE: --enable-flashinfer-allreduce-fusion is deprecated. Please set `--flashinfer-allreduce-fusion-backend=auto` instead. | `None` | N/A | | `--crash-on-nan` | Crash the server on nan logprobs. | `False` | Type: str | | `--hybrid-kvcache-ratio` | Mix ratio in [0,1] between uniform and hybrid kv buffers (0.0 = pure uniform: swa_size / full_size = 1)(1.0 = pure hybrid: swa_size / full_size = local_attention_size / context_length) | `None` | Optional[float] | | `--load-watch-interval` | The interval of load watching in seconds. 
| `0.1` | Type: float | diff --git a/python/sglang/srt/layers/communicator.py b/python/sglang/srt/layers/communicator.py index b354d0a04573..5eb69fe053b0 100644 --- a/python/sglang/srt/layers/communicator.py +++ b/python/sglang/srt/layers/communicator.py @@ -100,7 +100,7 @@ def apply_flashinfer_allreduce_fusion(batch_size: int): and batch_size > 0 and batch_size <= FUSE_ALLREDUCE_MAX_BATCH_SIZE and not is_dp_attention_enabled() - and get_global_server_args().enable_flashinfer_allreduce_fusion + and get_global_server_args().flashinfer_allreduce_fusion_backend is not None and not is_flashinfer_allreduce_unavailable() ) diff --git a/python/sglang/srt/layers/flashinfer_comm_fusion.py b/python/sglang/srt/layers/flashinfer_comm_fusion.py index c45ff977fcae..28bdd9fcc12c 100644 --- a/python/sglang/srt/layers/flashinfer_comm_fusion.py +++ b/python/sglang/srt/layers/flashinfer_comm_fusion.py @@ -2,18 +2,28 @@ from typing import Optional, Tuple import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup from sglang.srt.distributed import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, + get_tp_group, ) +from sglang.srt.distributed.parallel_state import in_the_same_node_as +from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import is_flashinfer_available from sglang.srt.utils.custom_op import register_custom_op logger = logging.getLogger(__name__) +# FlashInfer allreduce fusion: set when flashinfer is available (see block below) _flashinfer_comm = None _workspace_manager = None +_mnnvl_comm_backend = None +_AllReduceFusionPattern = None +_create_allreduce_fusion_workspace = None +_allreduce_fusion = None _flashinfer_allreduce_unavailable = False if is_flashinfer_available(): @@ -37,12 +47,102 @@ "implementation" ) + try: + # Try to import the unified allreduce API + from flashinfer.comm.allreduce import ( + allreduce_fusion, + create_allreduce_fusion_workspace, + ) + + # AllReduceFusionPattern 
might be in trtllm_ar or allreduce module + try: + from flashinfer.comm.allreduce import AllReduceFusionPattern + except ImportError: + from flashinfer.comm.trtllm_ar import AllReduceFusionPattern + + _AllReduceFusionPattern = AllReduceFusionPattern + _create_allreduce_fusion_workspace = create_allreduce_fusion_workspace + _allreduce_fusion = allreduce_fusion + except ImportError: + # Fall back to legacy API if unified API is not available + _AllReduceFusionPattern = None + _create_allreduce_fusion_workspace = None + _allreduce_fusion = None + logger.warning( + "FlashInfer unified allreduce API not available, using legacy API" + ) + + try: + from flashinfer.comm.mnnvl import CommBackend + + class TorchDistributedCommBackend(CommBackend): + """ + Use torch distributed instead of MPI to set up flashinfer MNNVL workspaces during initialization + """ + + def __init__(self, group: ProcessGroup): + self._group = group + + def Get_rank(self) -> int: + return self._group.rank() + + def Get_size(self) -> int: + return self._group.size() + + def allgather(self, data: int): + gathered = [None] * self.Get_size() + dist.all_gather_object(gathered, data, group=self._group) + return gathered + + def bcast(self, data, root: int = 0): + """ + Broadcast a picklable Python object from `root` to all ranks. + Uses torch.distributed.broadcast_object_list under the hood. + + Returns the broadcasted object on every rank. + """ + obj_list = [data] + # broadcast_object_list mutates obj_list in-place + dist.broadcast_object_list(obj_list, src=root, group=self._group) + return obj_list[0] + + def barrier(self): + """ + Synchronize all ranks in this communicator. 
+ """ + dist.barrier(group=self._group) + + def Split(self, color: int, key: int): + # No need to split, we already use the proper group + return self._group + + _mnnvl_comm_backend = TorchDistributedCommBackend + except ImportError: + _mnnvl_comm_backend = None + + +# FlashInfer allreduce fusion (fused allreduce + Residual + RMSNorm) backend support +# for --flashinfer-allreduce-fusion-backend: +# +# Feature / Framework | SM100 | SM90 | Single Node | Multi-Node | +# --------------------- | ----- | ---- | ----------- | ---------- | +# TRT-LLM AllReduce | Yes | Yes | Yes | No | +# MNNVL AllReduce | Yes | No | Yes | Yes | +# +# With backend "auto": trtllm is eligible on a single node only (SM90 or SM100); mnnvl is eligible on single- or multi-node, but only on SM100. +# Multi-node on Hopper (SM90) is unsupported (trtllm would be chosen but does not support multi-node, and mnnvl requires SM100). + def is_flashinfer_allreduce_unavailable() -> bool: return _flashinfer_allreduce_unavailable class FlashInferWorkspaceManager: + """ + Workspace manager using FlashInfer's unified allreduce API. + Wraps FlashInfer's create_allreduce_fusion_workspace() for automatic backend selection. 
+ """ + def __init__(self): self.workspace = None self.world_size = None @@ -51,6 +151,10 @@ def __init__(self): self.hidden_dim = None self.dtype = None self.initialized = False + # Max size ever requested (not cleared on cleanup) so we only grow and minimize recreates + self._max_token_num_seen: Optional[int] = None + self._max_hidden_dim_seen: Optional[int] = None + self._logged_init = False def initialize( self, @@ -58,27 +162,86 @@ def initialize( rank: int, max_token_num: int, hidden_dim: int, - dtype: torch.dtype, + backend: str = "auto", + group: Optional[ProcessGroup] = None, + use_fp32_lamport: bool = False, + dtype: Optional[torch.dtype] = None, use_oneshot: Optional[bool] = None, ): - """Initialize workspace""" + """Initialize workspace using FlashInfer's unified API""" + # Track max size ever requested so we can create with at least that (only grow, minimize recreates) + self._max_token_num_seen = max(max_token_num, self._max_token_num_seen or 0) + self._max_hidden_dim_seen = max(hidden_dim, self._max_hidden_dim_seen or 0) + # Reuse existing workspace if it already covers this problem size + if ( + self.initialized + and self.world_size == world_size + and self.is_buffer_size_sufficient( + token_num=max_token_num, + hidden_dim=hidden_dim, + dtype=dtype or torch.bfloat16, + use_oneshot=use_oneshot, + ) + ): + return + # Same world_size but buffer too small: free old workspace before creating new one + if self.initialized and self.world_size == world_size: + self.cleanup() + if _flashinfer_comm is None: logger.warning( "FlashInfer comm not available, skipping workspace initialization" ) return - self.cleanup() + # Determine GPUs per node for MNNVL backend + # FlashInfer will use this to determine topology internally + gpus_per_node = None + if group is not None: + gpus_per_node = sum(in_the_same_node_as(group, source_rank=0)) + + # Create comm backend for MNNVL if needed + comm_backend = None + if _mnnvl_comm_backend is not None and group is not None: + 
comm_backend = _mnnvl_comm_backend(group) + try: - self.workspace = _flashinfer_comm.create_allreduce_fusion_workspace( - backend="trtllm", + # Create with at least the max size we've ever been asked for (only grow, fewer recreates) + alloc_token_num = max(max_token_num, self._max_token_num_seen or 0) + alloc_hidden_dim = max(hidden_dim, self._max_hidden_dim_seen or 0) + # Use FlashInfer's unified API to create workspace + create_kw = dict( + backend=backend, world_size=world_size, rank=rank, - max_token_num=max_token_num, - hidden_dim=hidden_dim, - dtype=dtype, - force_oneshot_support=bool(use_oneshot), + max_token_num=alloc_token_num, + hidden_dim=alloc_hidden_dim, + dtype=dtype or torch.bfloat16, + gpus_per_node=gpus_per_node, + comm_backend=comm_backend, ) + if use_oneshot is not None: + create_kw["force_oneshot_support"] = bool(use_oneshot) + self.workspace = _create_allreduce_fusion_workspace(**create_kw) + self.world_size = world_size + self.rank = rank + self.max_token_num = alloc_token_num + self.hidden_dim = alloc_hidden_dim + self.dtype = dtype or torch.bfloat16 + self.initialized = True + + backend_name = getattr(self.workspace, "backend", "unknown") + if not self._logged_init: + logger.info( + f"FlashInfer workspace initialized for rank {rank}, " + f"world_size {world_size}, backend: {backend_name}" + ) + self._logged_init = True + else: + logger.debug( + f"FlashInfer workspace re-initialized for rank {rank}, " + f"world_size {world_size}, backend: {backend_name}" + ) except Exception as e: global _flashinfer_allreduce_unavailable _flashinfer_allreduce_unavailable = True @@ -88,20 +251,6 @@ def initialize( ) self.workspace = None self.initialized = False - return - - self.world_size = world_size - self.rank = rank - self.max_token_num = max_token_num - self.hidden_dim = hidden_dim - self.dtype = dtype - self.initialized = True - - backend = getattr(self.workspace, "backend", "unknown") - logger.info( - f"FlashInfer workspace initialized for rank 
{rank}, " - f"world_size {world_size}, backend {backend}" - ) def is_buffer_size_sufficient( self, @@ -122,13 +271,22 @@ def is_buffer_size_sufficient( ) except Exception as e: logger.debug(f"FlashInfer workspace size check failed: {e}") + # Fallback: some backends (e.g. MNNVL) may use a different API; reuse if within our allocated size + if ( + self.max_token_num is not None + and self.hidden_dim is not None + and token_num <= self.max_token_num + and hidden_dim <= self.hidden_dim + ): + return True return False def cleanup(self): - """Clean up workspace""" + """Clean up workspace.""" if self.workspace is not None: try: - self.workspace.destroy() + if hasattr(self.workspace, "destroy"): + self.workspace.destroy() except Exception as e: logger.warning(f"Failed to cleanup FlashInfer workspace: {e}") finally: @@ -139,6 +297,7 @@ def cleanup(self): self.max_token_num = None self.hidden_dim = None self.dtype = None + self._logged_init = False _workspace_manager = FlashInferWorkspaceManager() @@ -147,7 +306,9 @@ def cleanup(self): def ensure_workspace_initialized( max_token_num: int = 2048, hidden_dim: int = 4096, - dtype: torch.dtype = torch.float16, + use_fp32_lamport: bool = False, + dtype: Optional[torch.dtype] = None, + group: Optional[ProcessGroup] = None, token_num: Optional[int] = None, use_oneshot: Optional[bool] = None, ): @@ -164,6 +325,7 @@ def ensure_workspace_initialized( rank = get_tensor_model_parallel_rank() token_num = token_num or max_token_num + effective_dtype = dtype or torch.bfloat16 if ( not _workspace_manager.initialized @@ -172,16 +334,20 @@ def ensure_workspace_initialized( or not _workspace_manager.is_buffer_size_sufficient( token_num=token_num, hidden_dim=hidden_dim, - dtype=dtype, + dtype=effective_dtype, use_oneshot=use_oneshot, ) ): + backend = get_global_server_args().flashinfer_allreduce_fusion_backend or "auto" _workspace_manager.initialize( world_size=world_size, rank=rank, max_token_num=max_token_num, hidden_dim=hidden_dim, + 
backend=backend, + use_fp32_lamport=use_fp32_lamport, dtype=dtype, + group=group, use_oneshot=use_oneshot, ) @@ -218,7 +384,8 @@ def flashinfer_allreduce_residual_rmsnorm( fp32_acc: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: """ - Use FlashInfer's fused allreduce + residual + RMS norm operation + Use FlashInfer's unified fused allreduce + residual + RMS norm operation. + Automatically selects between IPC and MNNVL backends based on topology and hardware. Args: input_tensor: Input tensor that needs allreduce @@ -233,7 +400,7 @@ def flashinfer_allreduce_residual_rmsnorm( Returns: Tuple[torch.Tensor, torch.Tensor]: (norm_output, residual_output) """ - if not is_flashinfer_available() or _flashinfer_comm is None: + if not is_flashinfer_available(): logger.debug( "FlashInfer not available, falling back to standard implementation" ) @@ -253,37 +420,57 @@ def flashinfer_allreduce_residual_rmsnorm( logger.debug("Non-contiguous tensors, skipping FlashInfer allreduce fusion") return None, None + # Get TP group for workspace initialization + try: + group = get_tp_group().cpu_group + except Exception: + group = None + if not ensure_workspace_initialized( max_token_num=max_token_num, hidden_dim=input_tensor.shape[-1], + use_fp32_lamport=(input_tensor.dtype == torch.float32), dtype=input_tensor.dtype, + group=group, token_num=input_tensor.shape[0], use_oneshot=use_oneshot, ): logger.debug("FlashInfer workspace not available") return None, None + if _workspace_manager.workspace is None: + logger.debug("FlashInfer workspace is None") + return None, None + residual_out = torch.empty_like(residual) norm_out = torch.empty_like(input_tensor) - _flashinfer_comm.allreduce_fusion( - input=input_tensor, - workspace=_workspace_manager.workspace, - pattern=_flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm, - launch_with_pdl=True, - residual_out=residual_out, - norm_out=norm_out, - residual_in=residual, - rms_gamma=weight, - rms_eps=eps, - use_oneshot=use_oneshot, - 
fp32_acc=fp32_acc, - ) + try: + if _AllReduceFusionPattern is None or _allreduce_fusion is None: + return None, None + + _allreduce_fusion( + input=input_tensor, + workspace=_workspace_manager.workspace, + pattern=_AllReduceFusionPattern.kARResidualRMSNorm, + launch_with_pdl=trigger_completion_at_end, + use_oneshot=use_oneshot, + fp32_acc=fp32_acc, + residual_in=residual, + residual_out=residual_out, + norm_out=norm_out, + rms_gamma=weight, + rms_eps=eps, + ) + except Exception as e: + logger.warning(f"FlashInfer allreduce fusion failed: {e}") + return None, None return norm_out, residual_out def cleanup_flashinfer_workspace(): + """Clean up FlashInfer workspace""" global _workspace_manager if _workspace_manager is not None: _workspace_manager.cleanup() diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index a51168c710ca..4d3d1ced2d57 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -512,6 +512,9 @@ class ServerArgs: ] = "none" moe_runner_backend: str = "auto" flashinfer_mxfp4_moe_precision: Literal["default", "bf16"] = "default" + flashinfer_allreduce_fusion_backend: Optional[ + Literal["auto", "trtllm", "mnnvl"] + ] = None enable_flashinfer_allreduce_fusion: bool = False enable_aiter_allreduce_fusion: bool = False deepep_mode: Literal["auto", "normal", "low_latency"] = "auto" @@ -892,6 +895,18 @@ def _handle_deprecated_args(self): ) self.tool_call_parser = deprecated_tool_call_parsers[self.tool_call_parser] + # When user passes --enable-flashinfer-allreduce-fusion, enable with auto backend + if ( + self.enable_flashinfer_allreduce_fusion + and self.flashinfer_allreduce_fusion_backend is None + ): + logger.warning( + "--enable-flashinfer-allreduce-fusion is deprecated. " + "Please use --flashinfer-allreduce-fusion-backend=auto instead." 
+ ) + self.flashinfer_allreduce_fusion_backend = "auto" + self.enable_flashinfer_allreduce_fusion = False + if self.enable_nan_detection: logger.warning( "--enable-nan-detection is deprecated. " @@ -1641,7 +1656,7 @@ def _handle_model_specific_adjustments(self): if is_blackwell_supported(): # workaround for https://github.com/flashinfer-ai/flashinfer/issues/2006 if not self.enable_dp_attention and self.nnodes == 1: - self.enable_flashinfer_allreduce_fusion = True + self.flashinfer_allreduce_fusion_backend = "auto" logger.info( "Enable FlashInfer AllReduce Fusion on sm100 for GptOssForCausalLM" ) @@ -1975,16 +1990,15 @@ def _handle_model_specific_adjustments(self): "Overlap scheduler is disabled when using sparse head for embedding model." ) - # TRTLLM AllReduce Fusion supports SM90/100, enable it by default - # for models with explicit support (DeepseekV3, GptOss, Glm4Moe, Qwen3Moe) - # TODO: currently, it is only supported in the single node scenario. https://github.com/flashinfer-ai/flashinfer/issues/2006 + # FlashInfer allreduce fusion: auto-enable when single-node (any SM90/100) or multi-node + Blackwell. + # See sglang.srt.layers.flashinfer_comm_fusion for backend support table (TRT-LLM vs MNNVL, SM90/100, single/multi-node). 
# TODO: there is currently a bug on H20 device specifically, https://github.com/flashinfer-ai/flashinfer/issues/2204 device_name = get_device_name() is_h20_device = ( device_name and "H20" in device_name and "H200" not in device_name ) if ( - not self.enable_flashinfer_allreduce_fusion + self.flashinfer_allreduce_fusion_backend is None and model_arch in [ "DeepseekV3ForCausalLM", @@ -1996,11 +2010,11 @@ def _handle_model_specific_adjustments(self): ] and (is_sm90_supported() or is_sm100_supported()) and not self.enable_dp_attention - and self.nnodes == 1 and not is_h20_device + and (self.nnodes == 1 or is_sm100_supported()) and self.moe_a2a_backend == "none" ): - self.enable_flashinfer_allreduce_fusion = True + self.flashinfer_allreduce_fusion_backend = "auto" def _handle_mamba_radix_cache( self, @@ -4600,10 +4614,21 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.flashinfer_mxfp4_moe_precision, help="Choose the computation precision of flashinfer mxfp4 moe", ) + parser.add_argument( + "--flashinfer-allreduce-fusion-backend", + type=str, + choices=["auto", "trtllm", "mnnvl"], + default=None, + help="Enable FlashInfer allreduce fusion and choose backend. When not set, the feature is disabled. " + "Options: 'auto' (choose best), 'trtllm' (SM90/100, single-node only), 'mnnvl' (SM100, single/multi-node). 
" + "Fuses allreduce with Residual + RMSNorm for supported MoE models.", + ) parser.add_argument( "--enable-flashinfer-allreduce-fusion", - action="store_true", - help="Enable FlashInfer allreduce fusion with Residual RMSNorm.", + action=DeprecatedStoreTrueAction, + new_flag="--flashinfer-allreduce-fusion-backend=auto", + help="(Deprecated: use --flashinfer-allreduce-fusion-backend=auto) " + "Enable FlashInfer allreduce fusion with Residual RMSNorm.", ) parser.add_argument( "--enable-aiter-allreduce-fusion", @@ -5732,6 +5757,14 @@ def check_server_args(self): f"Invalid value: '{self.served_model_name}'" ) + # FlashInfer allreduce fusion: mnnvl backend requires Blackwell (SM100) + if self.flashinfer_allreduce_fusion_backend == "mnnvl": + if not is_sm100_supported(): + raise ValueError( + "FlashInfer allreduce fusion backend 'mnnvl' is only supported on Blackwell GPUs (SM100). " + "On Hopper (SM90) or other architectures, use --flashinfer-allreduce-fusion-backend=trtllm or --flashinfer-allreduce-fusion-backend=auto instead." + ) + # Check LoRA self.check_lora_server_args()