Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions python/sglang/srt/environ.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,9 @@ class Envs:
# Default to the pick from flashinfer
SGLANG_FLASHINFER_FP4_GEMM_BACKEND = EnvStr("")
SGLANG_FLASHINFER_WORKSPACE_SIZE = EnvInt(384 * 1024 * 1024)
# TODO(mmangkad): Remove this once the FlashInfer unified allreduce-fusion
# transport issue on GB200/GB300 platforms is fixed and verified resolved.
SGLANG_FLASHINFER_FORCE_POSIX_FD_TRANSPORT = EnvBool(None)

# Triton
SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS = EnvBool(False)
Expand Down
89 changes: 80 additions & 9 deletions python/sglang/srt/layers/flashinfer_comm_fusion.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import contextlib
import logging
import platform
from typing import Optional, Tuple

import torch
Expand All @@ -7,6 +9,7 @@
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from sglang.srt.environ import envs
from sglang.srt.utils import is_flashinfer_available
from sglang.srt.utils.custom_op import register_custom_op

Expand All @@ -15,6 +18,73 @@
_flashinfer_comm = None
_workspace_manager = None
_flashinfer_allreduce_unavailable = False
_posix_transport_override_logged = False


def _should_force_posix_fd_transport() -> bool:
    """Decide whether FlashInfer must be forced onto PosixFD handle exchange.

    An explicit ``SGLANG_FLASHINFER_FORCE_POSIX_FD_TRANSPORT`` setting always
    wins. Otherwise the workaround auto-enables on aarch64/arm64 hosts whose
    current CUDA device reports compute capability major 10 (GB200/GB300-class
    systems); any failure to probe the device disables it.
    """
    explicit_setting = envs.SGLANG_FLASHINFER_FORCE_POSIX_FD_TRANSPORT.get()
    if explicit_setting is not None:
        return explicit_setting

    # Auto-detection only concerns ARM hosts.
    if platform.machine().lower() not in ("aarch64", "arm64"):
        return False

    if not torch.cuda.is_available():
        return False

    try:
        capability = torch.cuda.get_device_capability(torch.cuda.current_device())
    except Exception as exc:
        # Best-effort probe: if we cannot read the capability, do not force
        # the workaround.
        logger.debug("Failed to get CUDA device capability: %s", exc)
        return False

    return capability[0] == 10


@contextlib.contextmanager
def _flashinfer_posix_fd_transport_override_if_needed():
    """Temporarily disable FlashInfer's MNNVL fabric support check.

    While active (and only when ``_should_force_posix_fd_transport()`` says
    the workaround applies), ``flashinfer.comm.mnnvl.is_mnnvl_fabric_supported``
    is replaced with a stub that always reports fabric as unsupported, pushing
    FlashInfer onto PosixFD symmetric-memory handle exchange; the original
    checker is restored on exit. If the mnnvl module or its checker cannot be
    found, the context manager degrades to a no-op.
    """
    # TODO(mmangkad): Remove this temporary override once the
    # FlashInfer unified allreduce-fusion transport issue on
    # GB200/GB300 platforms is fixed and verified resolved.
    global _posix_transport_override_logged

    if not _should_force_posix_fd_transport():
        yield
        return

    try:
        import flashinfer.comm.mnnvl as flashinfer_mnnvl
    except Exception as exc:
        logger.debug(
            "Failed to import flashinfer.comm.mnnvl for transport override: %s", exc
        )
        yield
        return

    fabric_checker = getattr(flashinfer_mnnvl, "is_mnnvl_fabric_supported", None)
    if fabric_checker is None:
        yield
        return

    # Warn once per process, not on every workspace creation.
    if not _posix_transport_override_logged:
        logger.warning(
            "Applying FlashInfer transport workaround: forcing PosixFD "
            "symmetric-memory handle exchange on aarch64 + sm10x to avoid "
            "known data corruption with Fabric handle exchange on GB systems. "
            "Set SGLANG_FLASHINFER_FORCE_POSIX_FD_TRANSPORT=0 to disable."
        )
        _posix_transport_override_logged = True

    def _fabric_unsupported(_device_idx: int) -> bool:
        return False

    flashinfer_mnnvl.is_mnnvl_fabric_supported = _fabric_unsupported
    try:
        yield
    finally:
        flashinfer_mnnvl.is_mnnvl_fabric_supported = fabric_checker


if is_flashinfer_available():
try:
Expand Down Expand Up @@ -70,15 +140,16 @@ def initialize(

self.cleanup()
try:
self.workspace = _flashinfer_comm.create_allreduce_fusion_workspace(
backend="trtllm",
world_size=world_size,
rank=rank,
max_token_num=max_token_num,
hidden_dim=hidden_dim,
dtype=dtype,
force_oneshot_support=bool(use_oneshot),
)
with _flashinfer_posix_fd_transport_override_if_needed():
self.workspace = _flashinfer_comm.create_allreduce_fusion_workspace(
backend="trtllm",
world_size=world_size,
rank=rank,
max_token_num=max_token_num,
hidden_dim=hidden_dim,
dtype=dtype,
force_oneshot_support=bool(use_oneshot),
)
except Exception as e:
global _flashinfer_allreduce_unavailable
_flashinfer_allreduce_unavailable = True
Expand Down
Loading