Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions python/sglang/srt/layers/communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,6 @@ def apply_flashinfer_allreduce_fusion(batch_size: int):
and not is_dp_attention_enabled()
and get_global_server_args().enable_flashinfer_allreduce_fusion
and not is_flashinfer_allreduce_unavailable()
# FlashInfer's TRT-LLM allreduce backend creates its own NCCL communicator
# which doesn't support PyTorch sub-process groups used by context parallelism
and get_global_server_args().attn_cp_size <= 1
)


Expand Down Expand Up @@ -518,7 +515,7 @@ def prepare_attn(
) and hasattr(self.input_layernorm, "forward_with_allreduce_fusion"):
hidden_states, residual = (
self.input_layernorm.forward_with_allreduce_fusion(
hidden_states, residual, use_attn_tp_group=True
hidden_states, residual, use_attn_tp_group=False
Comment thread
Fridge003 marked this conversation as resolved.
)
)
else:
Expand Down
200 changes: 178 additions & 22 deletions python/sglang/srt/layers/flashinfer_comm_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,14 @@
from sglang.srt.distributed import (
get_attn_tensor_model_parallel_rank,
get_attn_tensor_model_parallel_world_size,
get_attn_tp_group,
get_moe_ep_group,
get_moe_expert_parallel_rank,
get_moe_expert_parallel_world_size,
get_moe_tensor_parallel_rank,
get_moe_tensor_parallel_world_size,
get_moe_tp_group,
get_tp_group,
)
from sglang.srt.environ import envs
from sglang.srt.utils import is_flashinfer_available
Expand All @@ -20,7 +24,7 @@
logger = logging.getLogger(__name__)

_flashinfer_comm = None
_workspace_manager = None
_TorchDistBackend = None
_flashinfer_allreduce_unavailable = False
_posix_transport_override_logged = False

Expand Down Expand Up @@ -111,6 +115,43 @@ def _always_disable_fabric(_device_idx: int) -> bool:
"implementation"
)

try:
from flashinfer.comm.mnnvl import TorchDistBackend

class _FixedTorchDistBackend(TorchDistBackend):
"""Workaround for FlashInfer TorchDistBackend issues.

1. bcast fix: TorchDistBackend.bcast passes the in-group rank
directly as `src` to broadcast_object_list, which expects a
global rank.
2. Graph-capture fix: initialize with NCCL device_group (so
the backend derives correct device_idx / GPU mapping), but
broadcast via GLOO cpu_group (to avoid NCCL collectives
that interfere with CUDA graph capture).
"""

def __init__(self, device_group, cpu_group):
super().__init__(group=device_group)
self._cpu_group = cpu_group

def bcast(self, data, root):
import torch.distributed as dist

group_ranks = dist.get_process_group_ranks(self._cpu_group)
global_root = group_ranks[root]
object_list = [data]
dist.broadcast_object_list(
object_list, src=global_root, group=self._cpu_group
)
return object_list[0]

_TorchDistBackend = _FixedTorchDistBackend
except ImportError:
logger.debug(
"flashinfer.comm.mnnvl.TorchDistBackend is not available, "
"allreduce fusion will use the default process group"
)


def is_flashinfer_allreduce_unavailable() -> bool:
return _flashinfer_allreduce_unavailable
Expand All @@ -121,6 +162,7 @@ def __init__(self):
self.workspace = None
self.world_size = None
self.rank = None
self.group = None
self.max_token_num = None
self.hidden_dim = None
self.dtype = None
Expand All @@ -134,6 +176,8 @@ def initialize(
hidden_dim: int,
dtype: torch.dtype,
use_oneshot: Optional[bool] = None,
device_group: Optional["torch.distributed.ProcessGroup"] = None,
cpu_group: Optional["torch.distributed.ProcessGroup"] = None,
):
"""Initialize workspace"""
if _flashinfer_comm is None:
Expand All @@ -144,15 +188,26 @@ def initialize(

self.cleanup()
try:
kwargs = dict(
backend="trtllm",
world_size=world_size,
rank=rank,
max_token_num=max_token_num,
hidden_dim=hidden_dim,
dtype=dtype,
force_oneshot_support=bool(use_oneshot),
)
if (
_TorchDistBackend is not None
and device_group is not None
and cpu_group is not None
):
kwargs["comm_backend"] = _TorchDistBackend(
device_group=device_group, cpu_group=cpu_group
)
with _flashinfer_posix_fd_transport_override_if_needed():
self.workspace = _flashinfer_comm.create_allreduce_fusion_workspace(
backend="trtllm",
world_size=world_size,
rank=rank,
max_token_num=max_token_num,
hidden_dim=hidden_dim,
dtype=dtype,
force_oneshot_support=bool(use_oneshot),
**kwargs
)
except Exception as e:
global _flashinfer_allreduce_unavailable
Expand All @@ -167,6 +222,7 @@ def initialize(

self.world_size = world_size
self.rank = rank
self.group = (device_group, cpu_group)
self.max_token_num = max_token_num
self.hidden_dim = hidden_dim
self.dtype = dtype
Expand Down Expand Up @@ -211,12 +267,51 @@ def cleanup(self):
self.initialized = False
self.world_size = None
self.rank = None
self.group = None
self.max_token_num = None
self.hidden_dim = None
self.dtype = None


_workspace_manager = FlashInferWorkspaceManager()
_attn_tp_workspace_manager = FlashInferWorkspaceManager()
Comment thread
Shunkangz marked this conversation as resolved.
_moe_tp_workspace_manager = FlashInferWorkspaceManager()


def _get_workspace_manager(use_attn_tp_group: bool) -> FlashInferWorkspaceManager:
return (
_attn_tp_workspace_manager if use_attn_tp_group else _moe_tp_workspace_manager
)


def _sync_allreduce_unavailable_across_tp():
"""Synchronize _flashinfer_allreduce_unavailable across all TP ranks.

If workspace initialization fails on any rank, all ranks must agree to
disable fusion. Otherwise ranks diverge during CUDA graph capture: some
use FlashInfer fusion (skipping custom allreduce), others fall back to
standard allreduce (calling register_buffer collectives), causing a hang
in register_graph_buffers.
"""
global _flashinfer_allreduce_unavailable
try:
import torch.distributed as dist

tp_group = get_tp_group()
if tp_group.world_size <= 1:
return
flag = torch.tensor(
[1 if _flashinfer_allreduce_unavailable else 0],
dtype=torch.int32,
)
dist.all_reduce(flag, op=dist.ReduceOp.MAX, group=tp_group.cpu_group)
if flag.item() > 0 and not _flashinfer_allreduce_unavailable:
_flashinfer_allreduce_unavailable = True
logger.warning(
"FlashInfer allreduce fusion disabled globally because "
"workspace initialization failed on at least one rank."
)
except Exception as e:
logger.debug(f"Failed to sync flashinfer unavailable flag: {e}")


def ensure_workspace_initialized(
Expand All @@ -234,46 +329,67 @@ def ensure_workspace_initialized(
if not is_flashinfer_available() or _flashinfer_comm is None:
return False

tp_coordinator = get_tp_group()

if use_attn_tp_group:
world_size = get_attn_tensor_model_parallel_world_size()
rank = get_attn_tensor_model_parallel_rank()
coordinator = get_attn_tp_group()
else:
# If MoE expert parallel world size > 1, use expert parallel group
# Otherwise, use tensor parallel group
# The two values cannot be larger than 1 at the same time
if get_moe_expert_parallel_world_size() > 1:
world_size = get_moe_expert_parallel_world_size()
rank = get_moe_expert_parallel_rank()
coordinator = get_moe_ep_group()
else:
world_size = get_moe_tensor_parallel_world_size()
rank = get_moe_tensor_parallel_rank()
coordinator = get_moe_tp_group()

# When the sub-group IS the full TP group, pass None so the workspace
# uses the default process group directly (no TorchDistBackend needed).
# For true sub-groups, use NCCL device_group for GPU/device mapping and
# GLOO cpu_group for metadata broadcasts (avoids NCCL collectives that
# interfere with CUDA graph capture).
if coordinator.device_group is tp_coordinator.device_group:
Comment thread
Fridge003 marked this conversation as resolved.
device_group = None
cpu_group = None
else:
device_group = coordinator.device_group
cpu_group = coordinator.cpu_group

if world_size <= 1:
return False

workspace_manager = _get_workspace_manager(use_attn_tp_group)
token_num = token_num or max_token_num
group_key = (device_group, cpu_group)

if (
not _workspace_manager.initialized
or _workspace_manager.world_size != world_size
or _workspace_manager.rank != rank
or not _workspace_manager.is_buffer_size_sufficient(
not workspace_manager.initialized
or workspace_manager.world_size != world_size
or workspace_manager.rank != rank
or workspace_manager.group != group_key
or not workspace_manager.is_buffer_size_sufficient(
token_num=token_num,
hidden_dim=hidden_dim,
dtype=dtype,
use_oneshot=use_oneshot,
)
):
_workspace_manager.initialize(
workspace_manager.initialize(
world_size=world_size,
rank=rank,
max_token_num=max_token_num,
hidden_dim=hidden_dim,
dtype=dtype,
use_oneshot=use_oneshot,
device_group=device_group,
cpu_group=cpu_group,
)

return _workspace_manager.initialized
_sync_allreduce_unavailable_across_tp()

return workspace_manager.initialized


def fake_flashinfer_allreduce_residual_rmsnorm(
Expand Down Expand Up @@ -368,9 +484,10 @@ def flashinfer_allreduce_residual_rmsnorm(
residual_out = torch.empty_like(residual)
norm_out = torch.empty_like(input_tensor)

workspace_manager = _get_workspace_manager(use_attn_tp_group)
_flashinfer_comm.allreduce_fusion(
input=input_tensor,
workspace=_workspace_manager.workspace,
workspace=workspace_manager.workspace,
pattern=_flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm,
launch_with_pdl=True,
residual_out=residual_out,
Expand All @@ -385,7 +502,46 @@ def flashinfer_allreduce_residual_rmsnorm(
return norm_out, residual_out


def pre_initialize_workspaces(
Comment thread
Shunkangz marked this conversation as resolved.
max_token_num: int,
hidden_dim: int,
dtype: torch.dtype,
use_oneshot: Optional[bool] = None,
):
"""Pre-initialize flashinfer workspaces before CUDA graph capture.

This must be called before graph capture to avoid collective operations
(broadcasts, barriers) inside the graph capture context, which can
deadlock with custom_all_reduce.register_graph_buffers.
"""
if _flashinfer_allreduce_unavailable or _flashinfer_comm is None:
return

# Initialize MoE workspace
ensure_workspace_initialized(
Comment thread
Shunkangz marked this conversation as resolved.
max_token_num=max_token_num,
hidden_dim=hidden_dim,
dtype=dtype,
use_oneshot=use_oneshot,
use_attn_tp_group=False,
)

# Initialize attention workspace
ensure_workspace_initialized(
max_token_num=max_token_num,
hidden_dim=hidden_dim,
dtype=dtype,
use_oneshot=use_oneshot,
use_attn_tp_group=True,
)


def cleanup_flashinfer_workspace():
global _workspace_manager
if _workspace_manager is not None:
_workspace_manager.cleanup()
global _attn_tp_workspace_manager, _moe_tp_workspace_manager
if _attn_tp_workspace_manager is not None:
_attn_tp_workspace_manager.cleanup()
if (
_moe_tp_workspace_manager is not None
and _moe_tp_workspace_manager is not _attn_tp_workspace_manager
):
_moe_tp_workspace_manager.cleanup()
22 changes: 22 additions & 0 deletions python/sglang/srt/model_executor/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -719,6 +719,7 @@ def initialize(self, pre_model_load_memory: float):
self.init_cublas()
self.init_attention_backend()
self.kernel_warmup()
self._pre_initialize_flashinfer_allreduce_workspace()
self.init_device_graphs()
elif self.device in ["npu", "cpu"]:
self.init_attention_backend()
Expand Down Expand Up @@ -2181,6 +2182,27 @@ def kernel_warmup(self):
if self._should_run_flashinfer_autotune():
self._flashinfer_autotune()

def _pre_initialize_flashinfer_allreduce_workspace(self):
Comment thread
Fridge003 marked this conversation as resolved.
"""Pre-initialize flashinfer allreduce fusion workspaces.

Must run before CUDA graph capture to avoid collective operations
(broadcasts, barriers) inside the graph capture context, which can
deadlock with custom_all_reduce.register_graph_buffers.
"""
if not self.server_args.enable_flashinfer_allreduce_fusion:
return

from sglang.srt.layers.communicator import FUSE_ALLREDUCE_MAX_BATCH_SIZE
from sglang.srt.layers.flashinfer_comm_fusion import (
pre_initialize_workspaces,
)

pre_initialize_workspaces(
max_token_num=FUSE_ALLREDUCE_MAX_BATCH_SIZE,
hidden_dim=self.model_config.hidden_size,
dtype=self.dtype,
)

def _should_run_flashinfer_autotune(self) -> bool:
"""Check if flashinfer autotune should be run."""
if self.server_args.disable_flashinfer_autotune:
Expand Down
1 change: 0 additions & 1 deletion python/sglang/srt/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -2214,7 +2214,6 @@ def _handle_model_specific_adjustments(self):
and (is_sm90_supported() or is_sm100_supported())
and self.tp_size > 1
and not self.enable_dp_attention
and self.attn_cp_size <= 1
Comment thread
Shunkangz marked this conversation as resolved.
and self.nnodes == 1
and not is_h20_device
and self.moe_a2a_backend == "none"
Expand Down
Loading