From 06e12204863e8ee97028d2bb3d6661ef88631664 Mon Sep 17 00:00:00 2001
From: Mohammad Miadh Angkad <176301910+mmangkad@users.noreply.github.com>
Date: Tue, 17 Mar 2026 03:19:53 +0800
Subject: [PATCH] Force PosixFD transport for FlashInfer allreduce fusion on GB200/GB300

---
 python/sglang/srt/environ.py                  |  3 +
 .../srt/layers/flashinfer_comm_fusion.py      | 89 +++++++++++++++++--
 2 files changed, 83 insertions(+), 9 deletions(-)

diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py
index e9c6481ba9f8..7d738cbcbaa8 100644
--- a/python/sglang/srt/environ.py
+++ b/python/sglang/srt/environ.py
@@ -341,6 +341,9 @@ class Envs:
     # Default to the pick from flashinfer
     SGLANG_FLASHINFER_FP4_GEMM_BACKEND = EnvStr("")
     SGLANG_FLASHINFER_WORKSPACE_SIZE = EnvInt(384 * 1024 * 1024)
+    # TODO(mmangkad): Remove this once the FlashInfer unified allreduce-fusion
+    # transport issue on GB200/GB300 platforms is fixed and verified.
+    SGLANG_FLASHINFER_FORCE_POSIX_FD_TRANSPORT = EnvBool(None)
 
     # Triton
     SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS = EnvBool(False)

diff --git a/python/sglang/srt/layers/flashinfer_comm_fusion.py b/python/sglang/srt/layers/flashinfer_comm_fusion.py
index c45ff977fcae..9f7eb2461057 100644
--- a/python/sglang/srt/layers/flashinfer_comm_fusion.py
+++ b/python/sglang/srt/layers/flashinfer_comm_fusion.py
@@ -1,4 +1,6 @@
+import contextlib
 import logging
+import platform
 from typing import Optional, Tuple
 
 import torch
@@ -7,6 +9,7 @@
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
+from sglang.srt.environ import envs
 from sglang.srt.utils import is_flashinfer_available
 from sglang.srt.utils.custom_op import register_custom_op
 
@@ -15,6 +18,73 @@
 _flashinfer_comm = None
 _workspace_manager = None
 _flashinfer_allreduce_unavailable = False
+_posix_transport_override_logged = False
+
+
+def _should_force_posix_fd_transport() -> bool:
+    force_posix_env = envs.SGLANG_FLASHINFER_FORCE_POSIX_FD_TRANSPORT.get()
+    if force_posix_env is not None:
+        return force_posix_env
+
+    machine = platform.machine().lower()
+    if machine not in ("aarch64", "arm64"):
+        return False
+
+    if not torch.cuda.is_available():
+        return False
+
+    try:
+        major, _minor = torch.cuda.get_device_capability(torch.cuda.current_device())
+    except Exception as e:
+        logger.debug("Failed to get CUDA device capability: %s", e)
+        return False
+
+    return major == 10
+
+
+@contextlib.contextmanager
+def _flashinfer_posix_fd_transport_override_if_needed():
+    # TODO(mmangkad): Remove this temporary override once the
+    # FlashInfer unified allreduce-fusion transport issue on
+    # GB200/GB300 platforms is fixed and verified.
+    global _posix_transport_override_logged
+
+    if not _should_force_posix_fd_transport():
+        yield
+        return
+
+    try:
+        import flashinfer.comm.mnnvl as flashinfer_mnnvl
+    except Exception as e:
+        logger.debug(
+            "Failed to import flashinfer.comm.mnnvl for transport override: %s", e
+        )
+        yield
+        return
+
+    original_checker = getattr(flashinfer_mnnvl, "is_mnnvl_fabric_supported", None)
+    if original_checker is None:
+        yield
+        return
+
+    if not _posix_transport_override_logged:
+        logger.warning(
+            "Applying FlashInfer transport workaround: forcing PosixFD "
+            "symmetric-memory handle exchange on aarch64 + sm10x to avoid "
+            "known data corruption with Fabric handle exchange on GB systems. "
+            "Set SGLANG_FLASHINFER_FORCE_POSIX_FD_TRANSPORT=0 to disable."
+        )
+        _posix_transport_override_logged = True
+
+    def _always_disable_fabric(_device_idx: int) -> bool:
+        return False
+
+    flashinfer_mnnvl.is_mnnvl_fabric_supported = _always_disable_fabric
+    try:
+        yield
+    finally:
+        flashinfer_mnnvl.is_mnnvl_fabric_supported = original_checker
+
 
 if is_flashinfer_available():
     try:
@@ -70,15 +140,16 @@ def initialize(
             self.cleanup()
 
         try:
-            self.workspace = _flashinfer_comm.create_allreduce_fusion_workspace(
-                backend="trtllm",
-                world_size=world_size,
-                rank=rank,
-                max_token_num=max_token_num,
-                hidden_dim=hidden_dim,
-                dtype=dtype,
-                force_oneshot_support=bool(use_oneshot),
-            )
+            with _flashinfer_posix_fd_transport_override_if_needed():
+                self.workspace = _flashinfer_comm.create_allreduce_fusion_workspace(
+                    backend="trtllm",
+                    world_size=world_size,
+                    rank=rank,
+                    max_token_num=max_token_num,
+                    hidden_dim=hidden_dim,
+                    dtype=dtype,
+                    force_oneshot_support=bool(use_oneshot),
+                )
         except Exception as e:
             global _flashinfer_allreduce_unavailable
             _flashinfer_allreduce_unavailable = True
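Note for reviewers: the override above works by temporarily monkey-patching the
module-level probe flashinfer.comm.mnnvl.is_mnnvl_fabric_supported and restoring
it on exit. Below is a minimal, self-contained sketch of that pattern; the names
fake_mnnvl and force_posix_fd_transport are illustrative stand-ins, not part of
this patch.

import contextlib
import types

# Stand-in for the flashinfer.comm.mnnvl module; in the patch the real
# module is imported and patched in place.
fake_mnnvl = types.SimpleNamespace(
    is_mnnvl_fabric_supported=lambda device_idx: True
)

@contextlib.contextmanager
def force_posix_fd_transport(module):
    # Swap the fabric-support probe for one that always answers "no",
    # which makes the caller fall back to PosixFD handle exchange.
    original = module.is_mnnvl_fabric_supported
    module.is_mnnvl_fabric_supported = lambda _device_idx: False
    try:
        yield
    finally:
        # Restore the original checker even if the body raised.
        module.is_mnnvl_fabric_supported = original

print(fake_mnnvl.is_mnnvl_fabric_supported(0))      # True
with force_posix_fd_transport(fake_mnnvl):
    print(fake_mnnvl.is_mnnvl_fabric_supported(0))  # False inside the override
print(fake_mnnvl.is_mnnvl_fabric_supported(0))      # True again after exit

The try/finally in the context manager is the key design point: the original
checker is restored even if create_allreduce_fusion_workspace raises, so the
override never leaks beyond workspace creation.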