From 2ba0a8ffbcef2ad4539128397dce0f926e1f1cd1 Mon Sep 17 00:00:00 2001 From: Shunkang <182541032+Shunkangz@users.noreply.github.co> Date: Mon, 23 Mar 2026 18:56:31 -0700 Subject: [PATCH 1/3] Support allreduce fusion with cp --- python/sglang/srt/layers/communicator.py | 5 +- .../srt/layers/flashinfer_comm_fusion.py | 200 ++++++++++++++++-- .../sglang/srt/model_executor/model_runner.py | 37 ++++ python/sglang/srt/server_args.py | 1 - 4 files changed, 216 insertions(+), 27 deletions(-) diff --git a/python/sglang/srt/layers/communicator.py b/python/sglang/srt/layers/communicator.py index a1e9a36d4c3f..0b095310243f 100644 --- a/python/sglang/srt/layers/communicator.py +++ b/python/sglang/srt/layers/communicator.py @@ -160,9 +160,6 @@ def apply_flashinfer_allreduce_fusion(batch_size: int): and not is_dp_attention_enabled() and get_global_server_args().enable_flashinfer_allreduce_fusion and not is_flashinfer_allreduce_unavailable() - # FlashInfer's TRT-LLM allreduce backend creates its own NCCL communicator - # which doesn't support PyTorch sub-process groups used by context parallelism - and get_global_server_args().attn_cp_size <= 1 ) @@ -518,7 +515,7 @@ def prepare_attn( ) and hasattr(self.input_layernorm, "forward_with_allreduce_fusion"): hidden_states, residual = ( self.input_layernorm.forward_with_allreduce_fusion( - hidden_states, residual, use_attn_tp_group=True + hidden_states, residual, use_attn_tp_group=False ) ) else: diff --git a/python/sglang/srt/layers/flashinfer_comm_fusion.py b/python/sglang/srt/layers/flashinfer_comm_fusion.py index 9417d573cb14..086a2a46395b 100644 --- a/python/sglang/srt/layers/flashinfer_comm_fusion.py +++ b/python/sglang/srt/layers/flashinfer_comm_fusion.py @@ -8,10 +8,14 @@ from sglang.srt.distributed import ( get_attn_tensor_model_parallel_rank, get_attn_tensor_model_parallel_world_size, + get_attn_tp_group, + get_moe_ep_group, get_moe_expert_parallel_rank, get_moe_expert_parallel_world_size, get_moe_tensor_parallel_rank, get_moe_tensor_parallel_world_size, + get_moe_tp_group, + get_tp_group, ) from sglang.srt.environ import envs from sglang.srt.utils import is_flashinfer_available @@ -20,7 +24,7 @@ logger = logging.getLogger(__name__) _flashinfer_comm = None -_workspace_manager = None +_TorchDistBackend = None _flashinfer_allreduce_unavailable = False _posix_transport_override_logged = False @@ -111,6 +115,43 @@ def _always_disable_fabric(_device_idx: int) -> bool: "implementation" ) + try: + from flashinfer.comm.mnnvl import TorchDistBackend + + class _FixedTorchDistBackend(TorchDistBackend): + """Workaround for FlashInfer TorchDistBackend issues. + + 1. bcast fix: TorchDistBackend.bcast passes the in-group rank + directly as `src` to broadcast_object_list, which expects a + global rank. + 2. Graph-capture fix: initialize with NCCL device_group (so + the backend derives correct device_idx / GPU mapping), but + broadcast via GLOO cpu_group (to avoid NCCL collectives + that interfere with CUDA graph capture). + """ + + def __init__(self, device_group, cpu_group): + super().__init__(group=device_group) + self._cpu_group = cpu_group + + def bcast(self, data, root): + import torch.distributed as dist + + group_ranks = dist.get_process_group_ranks(self._cpu_group) + global_root = group_ranks[root] + object_list = [data] + dist.broadcast_object_list( + object_list, src=global_root, group=self._cpu_group + ) + return object_list[0] + + _TorchDistBackend = _FixedTorchDistBackend + except ImportError: + logger.debug( + "flashinfer.comm.mnnvl.TorchDistBackend is not available, " + "allreduce fusion will use the default process group" + ) + def is_flashinfer_allreduce_unavailable() -> bool: return _flashinfer_allreduce_unavailable @@ -121,6 +162,7 @@ def __init__(self): self.workspace = None self.world_size = None self.rank = None + self.group = None self.max_token_num = None self.hidden_dim = None self.dtype = None @@ -134,6 +176,8 @@ def initialize( hidden_dim: int, dtype: torch.dtype, use_oneshot: Optional[bool] = None, + device_group: Optional["torch.distributed.ProcessGroup"] = None, + cpu_group: Optional["torch.distributed.ProcessGroup"] = None, ): """Initialize workspace""" if _flashinfer_comm is None: @@ -144,15 +188,26 @@ def initialize( self.cleanup() try: + kwargs = dict( + backend="trtllm", + world_size=world_size, + rank=rank, + max_token_num=max_token_num, + hidden_dim=hidden_dim, + dtype=dtype, + force_oneshot_support=bool(use_oneshot), + ) + if ( + _TorchDistBackend is not None + and device_group is not None + and cpu_group is not None + ): + kwargs["comm_backend"] = _TorchDistBackend( + device_group=device_group, cpu_group=cpu_group + ) with _flashinfer_posix_fd_transport_override_if_needed(): self.workspace = _flashinfer_comm.create_allreduce_fusion_workspace( - backend="trtllm", - world_size=world_size, - rank=rank, - max_token_num=max_token_num, - hidden_dim=hidden_dim, - dtype=dtype, - force_oneshot_support=bool(use_oneshot), + **kwargs ) except Exception as e: global _flashinfer_allreduce_unavailable @@ -167,6 +222,7 @@ def initialize( self.world_size = world_size self.rank = rank + self.group = (device_group, cpu_group) self.max_token_num = max_token_num self.hidden_dim = hidden_dim self.dtype = dtype @@ -211,12 +267,51 @@ def cleanup(self): self.initialized = False self.world_size = None self.rank = None + self.group = None self.max_token_num = None self.hidden_dim = None self.dtype = None -_workspace_manager = FlashInferWorkspaceManager() +_attn_tp_workspace_manager = FlashInferWorkspaceManager() +_moe_tp_workspace_manager = FlashInferWorkspaceManager() + + +def _get_workspace_manager(use_attn_tp_group: bool) -> FlashInferWorkspaceManager: + return ( + _attn_tp_workspace_manager if use_attn_tp_group else _moe_tp_workspace_manager + ) + + +def _sync_allreduce_unavailable_across_tp(): + """Synchronize _flashinfer_allreduce_unavailable across all TP ranks. + + If workspace initialization fails on any rank, all ranks must agree to + disable fusion. Otherwise ranks diverge during CUDA graph capture: some + use FlashInfer fusion (skipping custom allreduce), others fall back to + standard allreduce (calling register_buffer collectives), causing a hang + in register_graph_buffers. + """ + global _flashinfer_allreduce_unavailable + try: + import torch.distributed as dist + + tp_group = get_tp_group() + if tp_group.world_size <= 1: + return + flag = torch.tensor( + [1 if _flashinfer_allreduce_unavailable else 0], + dtype=torch.int32, + ) + dist.all_reduce(flag, op=dist.ReduceOp.MAX, group=tp_group.cpu_group) + if flag.item() > 0 and not _flashinfer_allreduce_unavailable: + _flashinfer_allreduce_unavailable = True + logger.warning( + "FlashInfer allreduce fusion disabled globally because " + "workspace initialization failed on at least one rank." + ) + except Exception as e: + logger.debug(f"Failed to sync flashinfer unavailable flag: {e}") def ensure_workspace_initialized( @@ -234,46 +329,67 @@ def ensure_workspace_initialized( if not is_flashinfer_available() or _flashinfer_comm is None: return False + tp_coordinator = get_tp_group() + if use_attn_tp_group: world_size = get_attn_tensor_model_parallel_world_size() rank = get_attn_tensor_model_parallel_rank() + coordinator = get_attn_tp_group() else: - # If MoE expert parallel world size > 1, use expert parallel group - # Otherwise, use tensor parallel group - # The two values cannot be larger than 1 at the same time if get_moe_expert_parallel_world_size() > 1: world_size = get_moe_expert_parallel_world_size() rank = get_moe_expert_parallel_rank() + coordinator = get_moe_ep_group() else: world_size = get_moe_tensor_parallel_world_size() rank = get_moe_tensor_parallel_rank() + coordinator = get_moe_tp_group() + + # When the sub-group IS the full TP group, pass None so the workspace + # uses the default process group directly (no TorchDistBackend needed). + # For true sub-groups, use NCCL device_group for GPU/device mapping and + # GLOO cpu_group for metadata broadcasts (avoids NCCL collectives that + # interfere with CUDA graph capture). + if coordinator.device_group is tp_coordinator.device_group: + device_group = None + cpu_group = None + else: + device_group = coordinator.device_group + cpu_group = coordinator.cpu_group if world_size <= 1: return False + workspace_manager = _get_workspace_manager(use_attn_tp_group) token_num = token_num or max_token_num + group_key = (device_group, cpu_group) if ( - not _workspace_manager.initialized - or _workspace_manager.world_size != world_size - or _workspace_manager.rank != rank - or not _workspace_manager.is_buffer_size_sufficient( + not workspace_manager.initialized + or workspace_manager.world_size != world_size + or workspace_manager.rank != rank + or workspace_manager.group != group_key + or not workspace_manager.is_buffer_size_sufficient( token_num=token_num, hidden_dim=hidden_dim, dtype=dtype, use_oneshot=use_oneshot, ) ): - _workspace_manager.initialize( + workspace_manager.initialize( world_size=world_size, rank=rank, max_token_num=max_token_num, hidden_dim=hidden_dim, dtype=dtype, use_oneshot=use_oneshot, + device_group=device_group, + cpu_group=cpu_group, ) - return _workspace_manager.initialized + _sync_allreduce_unavailable_across_tp() + + return workspace_manager.initialized def fake_flashinfer_allreduce_residual_rmsnorm( @@ -368,9 +484,10 @@ def flashinfer_allreduce_residual_rmsnorm( residual_out = torch.empty_like(residual) norm_out = torch.empty_like(input_tensor) + workspace_manager = _get_workspace_manager(use_attn_tp_group) _flashinfer_comm.allreduce_fusion( input=input_tensor, - workspace=_workspace_manager.workspace, + workspace=workspace_manager.workspace, pattern=_flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm, launch_with_pdl=True, residual_out=residual_out, @@ -385,7 +502,46 @@ def flashinfer_allreduce_residual_rmsnorm( return norm_out, residual_out +def pre_initialize_workspaces( + max_token_num: int, + hidden_dim: int, + dtype: torch.dtype, + use_oneshot: Optional[bool] = None, +): + """Pre-initialize flashinfer workspaces before CUDA graph capture. + + This must be called before graph capture to avoid collective operations + (broadcasts, barriers) inside the graph capture context, which can + deadlock with custom_all_reduce.register_graph_buffers. + """ + if _flashinfer_allreduce_unavailable or _flashinfer_comm is None: + return + + # Initialize MoE workspace + ensure_workspace_initialized( + max_token_num=max_token_num, + hidden_dim=hidden_dim, + dtype=dtype, + use_oneshot=use_oneshot, + use_attn_tp_group=False, + ) + + # Initialize attention workspace + ensure_workspace_initialized( + max_token_num=max_token_num, + hidden_dim=hidden_dim, + dtype=dtype, + use_oneshot=use_oneshot, + use_attn_tp_group=True, + ) + + def cleanup_flashinfer_workspace(): - global _workspace_manager - if _workspace_manager is not None: - _workspace_manager.cleanup() + global _attn_tp_workspace_manager, _moe_tp_workspace_manager + if _attn_tp_workspace_manager is not None: + _attn_tp_workspace_manager.cleanup() + if ( + _moe_tp_workspace_manager is not None + and _moe_tp_workspace_manager is not _attn_tp_workspace_manager + ): + _moe_tp_workspace_manager.cleanup() diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index abc308302b7a..8199fe80b8b2 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -2181,6 +2181,43 @@ def kernel_warmup(self): if self._should_run_flashinfer_autotune(): self._flashinfer_autotune() + self._warmup_fused_sampling() + + def _warmup_fused_sampling(self): + """Pre-compile and autotune fused sampling Triton kernels.""" + if _is_hip: + return + from sglang.srt.layers.fused_sampling import warmup_fused_temperature_softmax + + logits_warmup_dtype = ( + torch.float32 if self.server_args.enable_fp32_lm_head else self.dtype + ) + warmup_fused_temperature_softmax( + self.model_config.vocab_size, + logits_dtype=logits_warmup_dtype, + ) + + def _pre_initialize_flashinfer_allreduce_workspace(self): + """Pre-initialize flashinfer allreduce fusion workspaces. + + Must run before CUDA graph capture to avoid collective operations + (broadcasts, barriers) inside the graph capture context, which can + deadlock with custom_all_reduce.register_graph_buffers. + """ + if not self.server_args.enable_flashinfer_allreduce_fusion: + return + + from sglang.srt.layers.communicator import FUSE_ALLREDUCE_MAX_BATCH_SIZE + from sglang.srt.layers.flashinfer_comm_fusion import ( + pre_initialize_workspaces, + ) + + pre_initialize_workspaces( + max_token_num=FUSE_ALLREDUCE_MAX_BATCH_SIZE, + hidden_dim=self.model_config.hidden_size, + dtype=self.dtype, + ) + def _should_run_flashinfer_autotune(self) -> bool: """Check if flashinfer autotune should be run.""" if self.server_args.disable_flashinfer_autotune: diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 2390b3414562..b606a0ce741d 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -2214,7 +2214,6 @@ def _handle_model_specific_adjustments(self): and (is_sm90_supported() or is_sm100_supported()) and self.tp_size > 1 and not self.enable_dp_attention - and self.attn_cp_size <= 1 and self.nnodes == 1 and not is_h20_device and self.moe_a2a_backend == "none" From 4e13e02f8997ab128b8bd081df65c1efd402a304 Mon Sep 17 00:00:00 2001 From: Shunkang <182541032+Shunkangz@users.noreply.github.co> Date: Wed, 8 Apr 2026 01:08:10 -0700 Subject: [PATCH 2/3] Fix deepseekv3 issue --- python/sglang/srt/model_executor/model_runner.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 8199fe80b8b2..9499aae9aa94 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -719,6 +719,7 @@ def initialize(self, pre_model_load_memory: float): self.init_cublas() self.init_attention_backend() self.kernel_warmup() + self._pre_initialize_flashinfer_allreduce_workspace() self.init_device_graphs() elif self.device in ["npu", "cpu"]: self.init_attention_backend() From 6ef849954df6beb115e6f80e13d1f61b38902577 Mon Sep 17 00:00:00 2001 From: Shunkang <182541032+Shunkangz@users.noreply.github.co> Date: Mon, 13 Apr 2026 06:31:26 -0700 Subject: [PATCH 3/3] Fix the merge conflict issue --- python/sglang/srt/model_executor/model_runner.py | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 9499aae9aa94..a17034ea979f 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -2182,22 +2182,6 @@ def kernel_warmup(self): if self._should_run_flashinfer_autotune(): self._flashinfer_autotune() - self._warmup_fused_sampling() - - def _warmup_fused_sampling(self): - """Pre-compile and autotune fused sampling Triton kernels.""" - if _is_hip: - return - from sglang.srt.layers.fused_sampling import warmup_fused_temperature_softmax - - logits_warmup_dtype = ( - torch.float32 if self.server_args.enable_fp32_lm_head else self.dtype - ) - warmup_fused_temperature_softmax( - self.model_config.vocab_size, - logits_dtype=logits_warmup_dtype, - ) - def _pre_initialize_flashinfer_allreduce_workspace(self): """Pre-initialize flashinfer allreduce fusion workspaces.