From 874c01459a905e16cb6fe99a995f038a1f4bc183 Mon Sep 17 00:00:00 2001 From: DarkSharpness <2040703891@qq.com> Date: Sat, 21 Mar 2026 22:07:18 +0800 Subject: [PATCH 1/6] refactor: clean up the parallel state --- .../kernels/all_reduce/benchmark_aiter.py | 2 + .../benchmark/bench_custom_all_reduce.py | 6 +- .../tests/test_custom_all_reduce.py | 6 +- .../srt/compilation/compilation_config.py | 4 +- .../device_communicators/__init__.py | 26 + .../device_communicators/all_reduce_utils.py | 16 - .../distributed/device_communicators/base.py | 264 ++++++++ .../device_communicators/custom_all_reduce.py | 97 +-- .../custom_all_reduce_aiter.py | 45 ++ .../custom_all_reduce_utils.py | 4 +- .../custom_all_reduce_v2.py | 60 +- .../device_communicators/hpu_communicator.py | 62 +- .../distributed/device_communicators/impl.py | 171 +++++ .../device_communicators/npu_communicator.py | 39 -- .../device_communicators/pymscclpp.py | 95 ++- .../device_communicators/pynccl.py | 194 +++--- .../device_communicators/pynccl_symm.py | 80 +++ .../device_communicators/quick_all_reduce.py | 118 ++-- .../device_communicators/torch_symm_mem.py | 108 +-- .../device_communicators/torch_wrapper.py | 95 +++ .../device_communicators/xpu_communicator.py | 43 +- .../sglang/srt/distributed/parallel_state.py | 618 ++++++------------ python/sglang/srt/layers/dp_attention.py | 7 +- .../srt/model_executor/mindspore_runner.py | 4 +- .../bench_amd_deterministic_allreduce.py | 2 +- ...test_amd_deterministic_custom_allreduce.py | 4 +- 26 files changed, 1282 insertions(+), 888 deletions(-) create mode 100644 python/sglang/srt/distributed/device_communicators/__init__.py delete mode 100644 python/sglang/srt/distributed/device_communicators/all_reduce_utils.py create mode 100644 python/sglang/srt/distributed/device_communicators/base.py create mode 100644 python/sglang/srt/distributed/device_communicators/custom_all_reduce_aiter.py create mode 100644 python/sglang/srt/distributed/device_communicators/impl.py delete mode 100644 python/sglang/srt/distributed/device_communicators/npu_communicator.py create mode 100644 python/sglang/srt/distributed/device_communicators/pynccl_symm.py create mode 100644 python/sglang/srt/distributed/device_communicators/torch_wrapper.py diff --git a/benchmark/kernels/all_reduce/benchmark_aiter.py b/benchmark/kernels/all_reduce/benchmark_aiter.py index bca45620784a..4943969c171d 100644 --- a/benchmark/kernels/all_reduce/benchmark_aiter.py +++ b/benchmark/kernels/all_reduce/benchmark_aiter.py @@ -107,6 +107,8 @@ def run_once(comm, inp: torch.Tensor) -> Optional[torch.Tensor]: return comm.all_reduce_unreg(inp) if hasattr(comm, "custom_all_reduce"): return comm.custom_all_reduce(inp) + if hasattr(comm, "all_reduce"): + return comm.all_reduce(inp) raise RuntimeError("No known all-reduce method found on the communicator.") diff --git a/python/sglang/jit_kernel/benchmark/bench_custom_all_reduce.py b/python/sglang/jit_kernel/benchmark/bench_custom_all_reduce.py index 4f36f1a48276..2bbe1301b93a 100644 --- a/python/sglang/jit_kernel/benchmark/bench_custom_all_reduce.py +++ b/python/sglang/jit_kernel/benchmark/bench_custom_all_reduce.py @@ -97,8 +97,7 @@ def capture(self, register_input: bool): return self.comm.capture() # ignore register_input since v1 always requires it def all_reduce(self, tensor: torch.Tensor) -> Optional[torch.Tensor]: - assert self.comm.should_custom_ar(tensor), str(tensor.shape) - return self.comm.custom_all_reduce(tensor) + return self.comm.all_reduce(tensor) class JITAllReduceBackend: @@ -118,8 +117,7 @@ def capture(self, register_input: bool): return self.comm.capture() if register_input else contextlib.nullcontext() def all_reduce(self, tensor: torch.Tensor) -> Optional[torch.Tensor]: - assert self.comm.should_custom_ar(tensor), str(tensor.shape) - return self.comm.custom_all_reduce(tensor) + return self.comm.all_reduce(tensor) class FlashInferAllReduceBackend: diff --git a/python/sglang/jit_kernel/tests/test_custom_all_reduce.py b/python/sglang/jit_kernel/tests/test_custom_all_reduce.py index 2d7e0253eb43..7f48fa6e0298 100644 --- a/python/sglang/jit_kernel/tests/test_custom_all_reduce.py +++ b/python/sglang/jit_kernel/tests/test_custom_all_reduce.py @@ -186,7 +186,7 @@ def get_run_graph_fn(): with comm.capture(): with torch.cuda.graph(graph): for i in range(TEST_LAYERS): - out_jits.append(comm.custom_all_reduce(graph_inp[i])) + out_jits.append(comm.all_reduce(graph_inp[i])) out_jit = torch.stack(out_jits) torch.cuda.synchronize() @@ -202,7 +202,7 @@ def run_eager(x: torch.Tensor) -> torch.Tensor: eager_inp = x.clone() out_eagers = [] for i in range(TEST_LAYERS): - out_eagers.append(comm.custom_all_reduce(eager_inp[i])) + out_eagers.append(comm.all_reduce(eager_inp[i])) torch.cuda.synchronize() return torch.stack(out_eagers) @@ -213,7 +213,7 @@ def run_eager(x: torch.Tensor) -> torch.Tensor: for _ in range(TEST_LOOP): # NOTE: 15 * 8 < 128, which is the precision limit for bf16 inp = torch.randint(0, 16, (TEST_LAYERS, size), dtype=dtype, device=device) - assert comm.should_custom_ar(inp[0]) + assert comm.can_all_reduce(inp[0]) out_ref = inp.clone() dist.all_reduce(out_ref, group=nccl_group) out_jit = run_fn(inp) diff --git a/python/sglang/srt/compilation/compilation_config.py b/python/sglang/srt/compilation/compilation_config.py index 0388bbedac06..c542e564867b 100644 --- a/python/sglang/srt/compilation/compilation_config.py +++ b/python/sglang/srt/compilation/compilation_config.py @@ -1,12 +1,12 @@ # Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/compilation_config.py -from typing import Callable, List, Optional +from typing import List, Optional SPLIT_OPS = [] def register_split_op(op_name: Optional[str] = None): - def decorator(op_func: Callable): + def decorator(op_func): name = op_name or op_func.__name__ SPLIT_OPS.append(f"sglang.{name}") return op_func diff --git a/python/sglang/srt/distributed/device_communicators/__init__.py b/python/sglang/srt/distributed/device_communicators/__init__.py new file mode 100644 index 000000000000..af4aa8f82bf4 --- /dev/null +++ b/python/sglang/srt/distributed/device_communicators/__init__.py @@ -0,0 +1,26 @@ +from .custom_all_reduce import dispatch_custom_allreduce +from .hpu_communicator import HpuCommunicator +from .impl import CommunicatorImpl +from .pymscclpp import PyMscclppCommunicator +from .pynccl import PyNcclCommunicator +from .pynccl_symm import PyNcclSymmMemCommunicator +from .quick_all_reduce import QuickAllReduce, qr_rocm_arch_available +from .shm_broadcast import MessageQueue +from .torch_symm_mem import TorchSymmMemCommunicator +from .torch_wrapper import TorchDefaultCommunicator +from .xpu_communicator import XpuCommunicator + +__all__ = [ + "PyMscclppCommunicator", + "PyNcclCommunicator", + "PyNcclSymmMemCommunicator", + "TorchSymmMemCommunicator", + "TorchDefaultCommunicator", + "QuickAllReduce", + "HpuCommunicator", + "XpuCommunicator", + "MessageQueue", + "dispatch_custom_allreduce", + "qr_rocm_arch_available", + "CommunicatorImpl", +] diff --git a/python/sglang/srt/distributed/device_communicators/all_reduce_utils.py b/python/sglang/srt/distributed/device_communicators/all_reduce_utils.py deleted file mode 100644 index ff88b6a1dcfd..000000000000 --- a/python/sglang/srt/distributed/device_communicators/all_reduce_utils.py +++ /dev/null @@ -1,16 +0,0 @@ -MiB = 1024 * 1024 - -TORCH_SYMM_MEM_ALL_REDUCE_MAX_SIZES = { - 9: { - 2: 64 * MiB, # 64 MB - 4: 64 * MiB, # 64 MB - 6: 128 * MiB, # 128 MB - 8: 128 * MiB, # 128 MB - }, - 10: { - 2: 64 * MiB, # 64 MB - 4: 64 * MiB, # 64 MB - 6: 128 * MiB, # 128 MB - 8: 128 * MiB, # 128 MB - }, -} diff --git a/python/sglang/srt/distributed/device_communicators/base.py b/python/sglang/srt/distributed/device_communicators/base.py new file mode 100644 index 000000000000..656721b54ecc --- /dev/null +++ b/python/sglang/srt/distributed/device_communicators/base.py @@ -0,0 +1,264 @@ +import contextlib +import enum +import functools +from typing import Any, ContextManager, List, Optional, Tuple + +import torch + + +# NOTE: Use concat-style all-gather here. +# Stack-style all-gather has compatibility issues with `torch.compile`. +# See https://github.com/pytorch/pytorch/issues/138795. +def allocate_all_gather(input_: torch.Tensor, world_size: int) -> torch.Tensor: + input_shape = input_.shape + return torch.empty( + (world_size * input_shape[0],) + input_shape[1:], + dtype=input_.dtype, + device=input_.device, + ) + + +def allocate_reduce_scatter(input_: torch.Tensor, world_size: int) -> torch.Tensor: + input_shape = input_.shape + assert input_shape[0] % world_size == 0 + return torch.empty( + (input_shape[0] // world_size,) + input_shape[1:], + dtype=input_.dtype, + device=input_.device, + ) + + +class AllReduceMode(enum.Enum): + BOTH = "both" + INPLACE = "inplace" + OUTPLACE = "outplace" + + +class BaseCommunicator: + name: str # should be set by subclass + + def __init__(self, world_size: int, disabled: bool = False): + self.world_size = world_size + self._disabled = disabled # NOTE: must use `change_state` to modify + + # Helper functions + + def assert_inplace(self, op: str, inplace: Optional[bool]): + if inplace == False: + raise ValueError(f"{self.name} does not allow out-of-place {op} now") + + def assert_outplace(self, op: str, inplace: Optional[bool]): + if inplace == True: + raise ValueError(f"{self.name} does not allow in-place {op} now") + + @staticmethod + def validate(f): + """ + Guard a public communicator method against calls while the communicator is + disabled. + """ + + @functools.wraps(f) + def wrapper(self, *args, **kwargs): + if self.disabled: + raise RuntimeError(f"{self.name} is disabled") + return f(self, *args, **kwargs) + + return wrapper + + def allocate_all_gather(self, input_: torch.Tensor) -> torch.Tensor: + return allocate_all_gather(input_, self.world_size) + + def allocate_reduce_scatter(self, input_: torch.Tensor) -> torch.Tensor: + return allocate_reduce_scatter(input_, self.world_size) + + # Public API + + @property + def disabled(self) -> bool: + """ + Whether this communicator is currently disabled. + + Public methods on this interface should not be called while the + communicator is disabled. Subclasses may override this property to add + derived enablement conditions on top of `_disabled`. + + Do not modify `self._disabled` directly outside this class. Use + `change_state()` instead. + """ + return self._disabled + + @contextlib.contextmanager + def change_state(self, enable: bool): + """ + Temporarily enable or disable the communicator within a context. + + :param enable: Whether the communicator should be enabled in the + wrapped block. + """ + old_value = self._disabled + self._disabled = not enable + try: + yield + finally: + self._disabled = old_value + + def graph_capture_context(self) -> Optional[ContextManager[Any]]: + """ + Return a context manager for graph capture, if the communicator needs + special handling during capture. + + Returning `None` means no extra handling is required. + """ + return None + + def should_use_custom_op(self) -> bool: + """ + Whether this communicator should use `register_custom_op` for `torch.compile` + compatibility. + If `False`, this means either: + 1. This backend doesn't support `torch.compile` + 2. This implementation is `torch.compile` friendly + """ + return False + + def can_all_reduce(self, input_: torch.Tensor) -> Optional[AllReduceMode]: + """ + Report the preferred all-reduce mode for `input_`. + + :param input_: Input tensor for the all-reduce. + :return: + - `AllReduceMode.INPLACE` if in-place all-reduce is preferred. + - `AllReduceMode.OUTPLACE` if out-of-place all-reduce is preferred. + - `AllReduceMode.BOTH` if both modes are fine. + - `None` if the communicator cannot run all-reduce on `input_` (e.g., + due to unsupported dtype, shape or alignment). + + This is orthogonal to `self.disabled`, which covers broader reasons why + the communicator is unavailable. + """ + raise NotImplementedError() + + def all_reduce( + self, + input_: torch.Tensor, + *, + inplace: Optional[bool] = None, + ) -> torch.Tensor: + """ + Run all-reduce on `input_`. + + Preconditions: + 1. `self.can_all_reduce(input_)` must not return `None`. + 2. `inplace=True` requires `self.can_all_reduce(input_).can_inplace()`. + 3. `inplace=False` requires `self.can_all_reduce(input_).can_outplace()`. + 4. `self.disabled` must be `False`. + + :param input_: Input tensor for the all-reduce. + :param inplace: Whether the operation should be in-place. If `None`, the + communicator may choose its preferred mode. If specified, it must be + consistent with `can_all_reduce(input_)`. + :return: The reduced tensor. If the operation is in-place, this must be + `input_` itself. + """ + raise NotImplementedError() + + def fused_all_reduce_rmsnorm( + self, + input_: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + eps: float, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Run fused all-reduce followed by RMSNorm. + + :param input_: Input tensor to reduce and normalize. + :param residual: Residual tensor used by RMSNorm. + :param weight: RMSNorm weight tensor. + :param eps: RMSNorm epsilon. + """ + raise NotImplementedError() + + def all_gather_into_tensor( + self, + input_: torch.Tensor, + *, + out: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Run concat-style all-gather on `input_`. + + :param input_: Input tensor for all-gather. + :param out: Optional preallocated output tensor. If omitted, the + communicator may allocate and return one. + :return: The gathered tensor. If `out` is provided, this must be `out`. + """ + raise NotImplementedError() + + def all_gather( + self, + input_: torch.Tensor, + *, + out_list: Optional[List[torch.Tensor]] = None, + ) -> List[torch.Tensor]: + """ + Run list-style all-gather on `input_`. + + :param input_: Input tensor for all-gather. + :param out_list: Optional preallocated output list. If provided, the + communicator should fill and return it. + :return: The gathered tensor list. If `out_list` is provided, this must + be `out_list`. + """ + raise NotImplementedError() + + def reduce_scatter_tensor( + self, + input_: torch.Tensor, + *, + out: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Run tensor-style reduce-scatter on `input_`. + + :param input_: Input tensor for reduce-scatter. + :param out: Optional preallocated output tensor. If omitted, the + communicator may allocate and return one. + :return: The reduced shard. If `out` is provided, this must be `out`. + """ + raise NotImplementedError() + + def reduce_scatter( + self, + input_list: List[torch.Tensor], + *, + out: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Run list-style reduce-scatter on `input_list`. + + :param input_list: Input tensor list for reduce-scatter. + :param out: Optional preallocated output tensor. If omitted, the + communicator may allocate and return one. + :return: The reduced shard. If `out` is provided, this must be `out`. + """ + raise NotImplementedError() + + def gather( + self, + input_: torch.Tensor, + dst: int, + *, + dim: int = 0, + ) -> Optional[torch.Tensor]: + """ + Gather `input_` to the destination rank. + + :param input_: Input tensor for gather. + :param dst: Destination rank within the communicator. + :param dim: Concatenation dimension in the returned tensor on the + destination rank. + :return: The gathered tensor on the destination rank, otherwise `None`. + """ + raise NotImplementedError() diff --git a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py index c9e055bb26dd..929c4cd4e391 100644 --- a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py +++ b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py @@ -3,7 +3,6 @@ import ctypes import logging from contextlib import contextmanager -from functools import partial from typing import Any, List, Optional, Union import torch @@ -12,11 +11,6 @@ import sglang.srt.distributed.device_communicators.custom_all_reduce_ops as ops from sglang.srt.compilation.piecewise_context_manager import is_in_piecewise_cuda_graph -from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary -from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import ( - can_use_custom_all_reduce_with_nvlink, - is_weak_contiguous, -) from sglang.srt.environ import envs from sglang.srt.utils import ( get_bool_env_var, @@ -26,6 +20,13 @@ log_info_on_rank0, ) +from .base import AllReduceMode, BaseCommunicator +from .cuda_wrapper import CudaRTLibrary +from .custom_all_reduce_utils import ( + can_use_custom_all_reduce_with_nvlink, + is_weak_contiguous, +) + _is_cuda = is_cuda() _is_hip = is_hip() _is_musa = is_musa() @@ -33,7 +34,9 @@ logger = logging.getLogger(__name__) -class CustomAllreduce: +class CustomAllreduce(BaseCommunicator): + name = "custom_all_reduce" + _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8] _MAX_CAR_SIZE = 8192 * 1024 if _is_hip: @@ -61,17 +64,14 @@ def __init__( are in the same node. """ self._IS_CAPTURING = False - self.disabled = True # This can be modified in-place by context manager in piecewise cuda graph runner - self.original_disabled = True # To store the original state self.use_amd_deterministic_impl = _use_amd_deterministic_impl() if not ops.IS_CUSTOM_AR_AVAILABLE: - # disable because of missing custom allreduce library - # e.g. in a non-cuda environment - return + raise RuntimeError("custom all-reduce library is not available") rank = dist.get_rank(group=group) world_size = dist.get_world_size(group=group) + assert world_size > 1 if isinstance(device, int): device = torch.device(f"cuda:{device}") @@ -87,13 +87,14 @@ def __init__( cls_name="CustomAllreduce", ) if full_nvlink is None: - return # fail to get nvlink status + raise RuntimeError("failed to determine NVLink connectivity") self.group = group self.max_size = max_size self.rank = rank self.world_size = world_size self.full_nvlink = full_nvlink + super().__init__(world_size=world_size) if not _is_hip: # Buffers memory are owned by this Python class and passed to C++. @@ -135,8 +136,6 @@ def __init__( ) self.register_buffer(self.buffer) - self.disabled = False - self.original_disabled = False # Ensure original_disabled == disabled self.tms_cudagraph = envs.SGLANG_MEMORY_SAVER_CUDA_GRAPH.get() @staticmethod @@ -189,6 +188,12 @@ def capture(self): if not self.disabled: self.register_graph_buffers() + def should_use_custom_op(self) -> bool: + return True + + def graph_capture_context(self): + return self.capture() + def _get_ipc_meta(self, inp: torch.Tensor): # _share_cuda_() doesn't accept meta buffer not allocated from # PyTorch cache allocator, use direct HIP call to get IPC handle @@ -253,30 +258,30 @@ def register_graph_buffers(self): offsets = [d[1] for d in all_data] # type: ignore ops.register_graph_buffers(self._ptr, handles, offsets) - def should_custom_ar(self, inp: torch.Tensor): + def can_all_reduce(self, input_: torch.Tensor) -> Optional[AllReduceMode]: if self.disabled: - return False - inp_size = inp.numel() * inp.element_size() + return None + inp_size = input_.numel() * input_.element_size() # custom allreduce requires input byte size to be multiples of 16 if inp_size % 16 != 0: - return False - if not is_weak_contiguous(inp): - return False + return None + if not is_weak_contiguous(input_): + return None # for 4 or more non NVLink-capable GPUs, custom allreduce provides # little performance improvement over NCCL. if not _is_hip: if self.world_size == 2 or self.full_nvlink: - return inp_size <= self.max_size - return False + return AllReduceMode.OUTPLACE if inp_size <= self.max_size else None + return None if _is_hip: if self.use_amd_deterministic_impl: - return True + return AllReduceMode.OUTPLACE if self.full_nvlink: - return inp_size <= self.max_size - return False + return AllReduceMode.OUTPLACE if inp_size <= self.max_size else None + return None - return False + return None def _all_reduce_impl(self, inp: torch.Tensor, registered: bool): out = torch.empty_like(inp) @@ -302,35 +307,42 @@ def _all_reduce_impl(self, inp: torch.Tensor, registered: bool): ops.all_reduce_unreg(self._ptr, inp, self.buffer, out) return out - def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]: - """The main allreduce API that provides support for cuda graph.""" - # When custom allreduce is disabled, this will be None. - if self.disabled or not self.should_custom_ar(input): - return None + @BaseCommunicator.validate + def all_reduce( + self, + input_: torch.Tensor, + *, + inplace: Optional[bool] = None, + ) -> torch.Tensor: + """The main all-reduce API with CUDA-graph-aware behavior.""" + self.assert_outplace("all_reduce", inplace) + if self.can_all_reduce(input_) is None: + raise ValueError("custom_all_reduce cannot handle this input") if self._IS_CAPTURING: if torch.cuda.is_current_stream_capturing(): - return self._all_reduce_impl(input, registered=not self.tms_cudagraph) + return self._all_reduce_impl(input_, registered=not self.tms_cudagraph) else: # Could be warmup OR piecewise cuda graph split op execution. # In piecewise cuda graph, split ops run eagerly outside the graph # but _IS_CAPTURING is still True. We need to do real all-reduce. if is_in_piecewise_cuda_graph(): # Split op execution - do real all-reduce - return self._all_reduce_impl(input, registered=False) + return self._all_reduce_impl(input_, registered=False) else: # True warmup - mimic the allocation pattern since custom # allreduce is out-of-place. - return torch.zeros_like(input) + return torch.zeros_like(input_) else: - return self._all_reduce_impl(input, registered=False) + return self._all_reduce_impl(input_, registered=False) def close(self): - if not self.disabled and self._ptr: + if not getattr(self, "_disabled", True) and getattr(self, "_ptr", 0): ops.dispose(self._ptr) if _is_cuda: self.free_shared_buffer(self.meta_ptrs) self.free_shared_buffer(self.buffer_ptrs) self._ptr = 0 + self._disabled = True def __del__(self): self.close() @@ -378,23 +390,16 @@ def dispatch_custom_allreduce(): if get_bool_env_var("SGLANG_USE_AITER_AR", default="true"): try: - from aiter.dist.device_communicators.custom_all_reduce import ( - CustomAllreduce as AiterCustomAllreduce, - ) + from .custom_all_reduce_aiter import AiterCustomAllReduce logger.info("[AR] Using AiterCustomAllreduce (AMD default)") - tms_cudagraph = envs.SGLANG_MEMORY_SAVER_CUDA_GRAPH.get() - return partial( - AiterCustomAllreduce, - enable_register_for_capturing=not tms_cudagraph, - ) + return AiterCustomAllReduce except ImportError as e: logger.warning( "[AR] Aiter custom all-reduce not available; " "falling back to sglang CustomAllreduce. Details: %s", e, ) - return CustomAllreduce return CustomAllreduce diff --git a/python/sglang/srt/distributed/device_communicators/custom_all_reduce_aiter.py b/python/sglang/srt/distributed/device_communicators/custom_all_reduce_aiter.py new file mode 100644 index 000000000000..4ce4fcb8c06f --- /dev/null +++ b/python/sglang/srt/distributed/device_communicators/custom_all_reduce_aiter.py @@ -0,0 +1,45 @@ +from typing import Optional + +import torch +from aiter.dist.device_communicators.custom_all_reduce import CustomAllreduce +from torch.distributed import ProcessGroup + +from sglang.srt.environ import envs + +from .base import AllReduceMode, BaseCommunicator + + +class AiterCustomAllReduce(BaseCommunicator): + name = "custom_all_reduce_aiter" + + def __init__(self, group: ProcessGroup, *args, **kwargs): + tms_cudagraph = envs.SGLANG_MEMORY_SAVER_CUDA_GRAPH.get() + self.comm = CustomAllreduce( + group, *args, **kwargs, enable_register_for_capturing=not tms_cudagraph + ) + + def graph_capture_context(self): + return self.comm.capture() + + def should_use_custom_op(self) -> bool: + return True + + @property + def disabled(self) -> bool: + return self._disabled or self.comm.disabled + + def can_all_reduce(self, input_: torch.Tensor) -> Optional[AllReduceMode]: + can_use = self.comm.should_custom_ar(input_) + return AllReduceMode.OUTPLACE if can_use else None + + @BaseCommunicator.validate + def all_reduce( + self, + input_: torch.Tensor, + *, + inplace: Optional[bool] = None, + ) -> torch.Tensor: + self.assert_outplace("all_reduce", inplace) + out = self.comm.custom_all_reduce(input_) + assert out is not None + return out diff --git a/python/sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py b/python/sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py index b1505177a7dd..4dc70e11ff15 100644 --- a/python/sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +++ b/python/sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py @@ -18,7 +18,6 @@ from typing_extensions import ParamSpec from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary -from sglang.srt.distributed.parallel_state import in_the_same_node_as from sglang.srt.utils import is_cuda, is_hip, is_musa logger = logging.getLogger(__name__) @@ -405,6 +404,9 @@ def can_use_custom_all_reduce_with_nvlink( supported_world_size: List[int], cls_name: str, ) -> Optional[bool]: # None if fail; otherwise return whether NVLink is available + # lazy init to avoid circular import + from sglang.srt.distributed.parallel_state import in_the_same_node_as + assert ( dist.get_backend(group) != dist.Backend.NCCL ), f"{cls_name} should be attached to a non-NCCL group." diff --git a/python/sglang/srt/distributed/device_communicators/custom_all_reduce_v2.py b/python/sglang/srt/distributed/device_communicators/custom_all_reduce_v2.py index 2006554f9fd6..32c110a72399 100644 --- a/python/sglang/srt/distributed/device_communicators/custom_all_reduce_v2.py +++ b/python/sglang/srt/distributed/device_communicators/custom_all_reduce_v2.py @@ -8,12 +8,14 @@ from torch.distributed import ProcessGroup from sglang.jit_kernel.all_reduce import AllReduceAlgo, get_custom_all_reduce_cls -from sglang.srt.distributed import is_in_piecewise_cuda_graph -from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import ( +from sglang.srt.compilation.piecewise_context_manager import is_in_piecewise_cuda_graph +from sglang.srt.utils import is_sm100_supported, log_info_on_rank0 + +from .base import AllReduceMode, BaseCommunicator +from .custom_all_reduce_utils import ( can_use_custom_all_reduce_with_nvlink, is_weak_contiguous, ) -from sglang.srt.utils import is_sm100_supported, log_info_on_rank0 logger = logging.getLogger(__name__) @@ -28,7 +30,9 @@ class ModeConfig: one_shot_pull_threshold: int # below this, use one-shot pull -class CustomAllReduceV2: +class CustomAllReduceV2(BaseCommunicator): + name = "custom_all_reduce_v2" + def __init__( self, group: ProcessGroup, @@ -37,7 +41,6 @@ def __init__( max_push_size: Optional[int] = None, ) -> None: _init_config() - self.disabled = True full_nvlink = can_use_custom_all_reduce_with_nvlink( group=group, device=device, @@ -45,11 +48,12 @@ def __init__( cls_name="CustomAllReduceV2", ) if full_nvlink != True: - return + raise RuntimeError("CustomAllReduceV2 requires full NVLink connectivity") self.group = group self.rank = dist.get_rank(group=self.group) self.world_size = dist.get_world_size(group=self.group) + super().__init__(world_size=self.world_size) self.override_shot(None) if max_pull_size is None: max_pull_size = 16 * 1024 * 1024 # default to 16MB @@ -67,9 +71,14 @@ def __init__( graph_input_count=131072, ) self._post_init_obj() - self.disabled = False log_info_on_rank0(logger, "Custom allreduce v2 initialized successfully") + def graph_capture_context(self): + return self.capture() + + def should_use_custom_op(self) -> bool: + return True + def override_shot(self, shot: int | None): if shot is None: self.config = THRESHOLD_2_SHOT_MAP[self.world_size] @@ -99,30 +108,39 @@ def capture(self): self.obj.register_inputs(result) log_info_on_rank0(logger, f"Registering {len(pairs)} cuda graph addresses") - def should_custom_ar(self, inp: torch.Tensor) -> bool: - """Check if the input tensor is suitable for custom all-reduce.""" + def can_all_reduce(self, input_: torch.Tensor) -> Optional[AllReduceMode]: if self.disabled: - return False - inp_size = inp.numel() * inp.element_size() + return None + inp_size = input_.numel() * input_.element_size() # custom allreduce requires input byte size to be multiples of 16 - if inp_size % 16 != 0: - return False - if not is_weak_contiguous(inp): - return False - return inp_size <= self.max_pull_size - - def custom_all_reduce(self, input: torch.Tensor) -> torch.Tensor: + if inp_size % 16 != 0 or inp_size > self.max_pull_size: + return None + if not is_weak_contiguous(input_): + return None + return AllReduceMode.OUTPLACE + + @BaseCommunicator.validate + def all_reduce( + self, + input_: torch.Tensor, + *, + inplace: Optional[bool] = None, + ) -> torch.Tensor: + self.assert_outplace("all_reduce", inplace) + if self.can_all_reduce(input_) is None: + raise ValueError("custom_all_reduce_v2 cannot handle this input") if is_in_piecewise_cuda_graph(): # disable inplace optimization try: self.obj.set_cuda_graph_capture(False) - return self._all_reduce(input) + return self._all_reduce(input_) finally: self.obj.set_cuda_graph_capture(True) - return self._all_reduce(input) + return self._all_reduce(input_) def close(self): - if not self.disabled and hasattr(self, "obj"): + if not getattr(self, "_disabled", True) and hasattr(self, "obj"): self.obj.free(self.group) + self._disabled = True def _all_reduce(self, input: torch.Tensor) -> torch.Tensor: """Perform the actual all-reduce via JIT kernel.""" diff --git a/python/sglang/srt/distributed/device_communicators/hpu_communicator.py b/python/sglang/srt/distributed/device_communicators/hpu_communicator.py index 722e494cf775..42a023bfe9df 100644 --- a/python/sglang/srt/distributed/device_communicators/hpu_communicator.py +++ b/python/sglang/srt/distributed/device_communicators/hpu_communicator.py @@ -1,49 +1,53 @@ # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/hpu_communicator.py +from typing import Optional + import torch import torch.distributed as dist from torch.distributed import ProcessGroup from sglang.srt.utils import is_hpu +from .base import AllReduceMode, BaseCommunicator + if is_hpu(): import habana_frameworks.torch as htorch # noqa: F401 -class HpuCommunicator: +class HpuCommunicator(BaseCommunicator): + name = "hpu" def __init__(self, group: ProcessGroup): - if not is_hpu(): - self.disabled = True - return - self.disabled = False self.group = group - self.world_size = dist.get_world_size(self.group) - - def all_reduce(self, x: torch.Tensor) -> torch.Tensor: + super().__init__(dist.get_world_size(self.group), disabled=not is_hpu()) + + def can_all_reduce(self, input_: torch.Tensor) -> Optional[AllReduceMode]: + return AllReduceMode.INPLACE + + @BaseCommunicator.validate + def all_reduce( + self, + input_: torch.Tensor, + *, + inplace: Optional[bool] = None, + ) -> torch.Tensor: + self.assert_inplace("all_reduce", inplace) # FIXME(kzawora): this is a workaround for a bug in Habana PT bridge # occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used # (which is required for tensor parallel HPUGraph inference) htorch.core.mark_step() - dist.all_reduce(x, group=self.group) - return x - - def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor: - world_size = self.world_size - if dim < 0: - # Convert negative dim to positive. - dim += x.dim() - input_size = x.size() - # Allocate output tensor. - output_tensor = torch.empty( - (world_size,) + input_size, dtype=x.dtype, device=x.device - ) - # All-gather. + dist.all_reduce(input_, group=self.group) + return input_ + + @BaseCommunicator.validate + def all_gather_into_tensor( + self, + input_: torch.Tensor, + *, + out: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if out is None: + out = self.allocate_all_gather(input_) htorch.core.mark_step() - dist.all_gather_into_tensor(output_tensor, x, group=self.group) - # Reshape - output_tensor = output_tensor.movedim(0, dim) - output_tensor = output_tensor.reshape( - input_size[:dim] + (world_size * input_size[dim],) + input_size[dim + 1 :] - ) - return output_tensor + dist.all_gather_into_tensor(out, input_, group=self.group) + return out diff --git a/python/sglang/srt/distributed/device_communicators/impl.py b/python/sglang/srt/distributed/device_communicators/impl.py new file mode 100644 index 000000000000..730a53ae816f --- /dev/null +++ b/python/sglang/srt/distributed/device_communicators/impl.py @@ -0,0 +1,171 @@ +from __future__ import annotations + +import weakref +from dataclasses import dataclass +from typing import TYPE_CHECKING, Callable, ContextManager, Dict, List, Optional + +import torch + +from sglang.srt.compilation.compilation_config import register_split_op +from sglang.srt.utils.custom_op import register_custom_op + +if TYPE_CHECKING: + from .base import AllReduceMode, BaseCommunicator + + +@dataclass +class CommunicatorImpl: + unique_name: str + world_size: int + capture_comms: List[BaseCommunicator] + all_reduce_comms: List[BaseCommunicator] + all_gather_comms: List[BaseCommunicator] + reduce_scatter_comms: List[BaseCommunicator] + + # NOTE: never use this, this is only kept for compatibility with + # python/sglang/srt/model_executor/mindspore_runner.py + device_group: torch.distributed.ProcessGroup + + def __post_init__(self): + _register_group(self) + + def graph_capture_context(self) -> List[ContextManager]: + ctx_list = [] + for comm in self.capture_comms: + if (comm_ctx := comm.graph_capture_context()) is not None: + ctx_list.append(comm_ctx) + return ctx_list + + def all_reduce( + self, + input_: torch.Tensor, + *, + inplace: Optional[bool] = None, + ) -> torch.Tensor: + for i, comm in enumerate(self.all_reduce_comms): + if comm.disabled: + continue + mode = comm.can_all_reduce(input_) + if mode is not None and _is_mode_supported(mode, inplace): + if not comm.should_use_custom_op(): + return comm.all_reduce(input_, inplace=inplace) + if _can_use_inplace(mode): # we prefer in-place if possible + inplace_all_reduce(input_, self.unique_name, i) + return input_ + else: + return outplace_all_reduce(input_, self.unique_name, i) + raise ValueError(f"No compatible all-reduce communicator found: {inplace = }") + + def reduce_scatter_tensor( + self, + input_: torch.Tensor, + *, + out: torch.Tensor, # for now, out is never None + ) -> torch.Tensor: + for i, comm in enumerate(self.reduce_scatter_comms): + if comm.disabled: + continue + if not comm.should_use_custom_op(): + return comm.reduce_scatter_tensor(input_, out=out) + inplace_reduce_scatter(input_, self.unique_name, i, out=out) + return out + raise ValueError("No compatible reduce-scatter communicator found") + + def all_gather_into_tensor( + self, + input_: torch.Tensor, + *, + out: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + for i, comm in enumerate(self.all_gather_comms): + if comm.disabled: + continue + if not comm.should_use_custom_op(): + return comm.all_gather_into_tensor(input_, out=out) + if out is not None: + inplace_all_gather(input_, self.unique_name, i, out=out) + return out + else: + return outplace_all_gather(input_, self.unique_name, i) + raise ValueError("No compatible all-gather communicator found") + + +# NOTE: never use any of the following functions/variable outside this module +# the only exception is +# python/sglang/srt/model_executor/mindspore_runner.py +# we keep backward compatibility for this file + +_GROUPS: Dict[str, Callable[[], Optional[CommunicatorImpl]]] = {} + + +def _register_group(group: CommunicatorImpl) -> None: + _GROUPS[group.unique_name] = weakref.ref(group) + + +def _get_group(group_name: str) -> CommunicatorImpl: + assert group_name in _GROUPS, f"Group {group_name} is not found." + group = _GROUPS[group_name]() + if group is None: + raise ValueError(f"Group {group_name} is destroyed.") + return group + + +def _fake_all_gather(input: torch.Tensor, group_name: str) -> torch.Tensor: + from sglang.srt.distributed.device_communicators.base import allocate_all_gather + + return allocate_all_gather(input, _get_group(group_name).world_size) + + +@register_custom_op(mutates_args=["input_"]) +@register_split_op() +def inplace_all_reduce(input_: torch.Tensor, group_name: str, method: int) -> None: + group = _get_group(group_name) + group.all_reduce_comms[method].all_reduce(input_, inplace=True) + + +@register_custom_op(out_shape="input_") +def outplace_all_reduce( + input_: torch.Tensor, group_name: str, method: int +) -> torch.Tensor: + group = _get_group(group_name) + return group.all_reduce_comms[method].all_reduce(input_, inplace=False) + + +@register_custom_op(mutates_args=["out"]) +def inplace_reduce_scatter( + input_: torch.Tensor, group_name: str, method: int, *, out: torch.Tensor +) -> None: + group = _get_group(group_name) + group.reduce_scatter_comms[method].reduce_scatter_tensor(input_, out=out) + + +@register_custom_op(mutates_args=["out"], out_shape="input_") +def inplace_all_gather( + input_: torch.Tensor, group_name: str, method: int, *, out: torch.Tensor +) -> None: + group = _get_group(group_name) + group.all_gather_comms[method].all_gather_into_tensor(input_, out=out) + + +@register_custom_op(fake_impl=_fake_all_gather) +def outplace_all_gather( + input_: torch.Tensor, group_name: str, method: int +) -> torch.Tensor: + group = _get_group(group_name) + return group.all_gather_comms[method].all_gather_into_tensor(input_) + + +# NOTE(dark): we don't make them class method due to conflict with piecewise cuda graph + + +def _can_use_inplace(mode: AllReduceMode) -> bool: + return mode.value != "outplace" + + +def _is_mode_supported(mode: AllReduceMode, inplace: Optional[bool]) -> bool: + if inplace is None: + return True + elif inplace: + return mode.value != "outplace" + else: + return mode.value != "inplace" diff --git a/python/sglang/srt/distributed/device_communicators/npu_communicator.py b/python/sglang/srt/distributed/device_communicators/npu_communicator.py deleted file mode 100644 index cb6eb88e39be..000000000000 --- a/python/sglang/srt/distributed/device_communicators/npu_communicator.py +++ /dev/null @@ -1,39 +0,0 @@ -import torch -import torch.distributed as dist -from torch.distributed import ProcessGroup - -from sglang.srt.utils import is_npu - - -class NpuCommunicator: - - def __init__(self, group: ProcessGroup): - if not is_npu(): - self.disabled = True - return - self.disabled = False - self.group = group - self.world_size = dist.get_world_size(self.group) - - def all_reduce(self, x: torch.Tensor) -> torch.Tensor: - dist.all_reduce(x, group=self.group) - return x - - def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor: - world_size = self.world_size - if dim < 0: - # Convert negative dim to positive. - dim += x.dim() - input_size = x.size() - output_size = (input_size[0] * world_size,) + input_size[1:] - # Allocate output tensor. - output_tensor = torch.empty(output_size, dtype=x.dtype, device=x.device) - # All-gather. - dist.all_gather_into_tensor(output_tensor, x, group=self.group) - # Reshape - output_tensor = output_tensor.reshape((world_size,) + input_size) - output_tensor = output_tensor.movedim(0, dim) - output_tensor = output_tensor.reshape( - input_size[:dim] + (world_size * input_size[dim],) + input_size[dim + 1 :] - ) - return output_tensor diff --git a/python/sglang/srt/distributed/device_communicators/pymscclpp.py b/python/sglang/srt/distributed/device_communicators/pymscclpp.py index e45093c782b2..5acced649896 100644 --- a/python/sglang/srt/distributed/device_communicators/pymscclpp.py +++ b/python/sglang/srt/distributed/device_communicators/pymscclpp.py @@ -2,17 +2,18 @@ import logging import math import os -from contextlib import contextmanager from enum import IntEnum from typing import Optional, Union import torch import torch.distributed as dist -from torch.distributed import ProcessGroup, ReduceOp +from torch.distributed import ProcessGroup import sglang.srt.distributed.device_communicators.custom_all_reduce_ops as ops from sglang.srt.utils import is_hip +from .base import AllReduceMode, BaseCommunicator + logger = logging.getLogger(__name__) _is_hip = is_hip() @@ -88,7 +89,8 @@ def mscclpp_bench_time(func, test_niter: int = 10, warmup_niter: int = 2): return func_cost_us -class PyMscclppCommunicator: +class PyMscclppCommunicator(BaseCommunicator): + name = "pymscclpp" _SUPPORTED_WORLD_SIZES = [8, 16] _MAX_BYTES = mscclpp_convert_to_bytes(os.getenv("SGLANG_MSCCLPP_MAX_BYTES", "1MB")) _SUPPORTED_DTYPE = [torch.float, torch.float16, torch.bfloat16] @@ -111,13 +113,14 @@ def __init__( is bind to a unique device, and all communicators in this group are in the same node. """ - self._IS_CAPTURING = False - self.disabled = True if not ops.IS_MSCCLPP_AR_AVAILABLE: # disable because of missing mscclpp library # e.g. in a non-cuda environment - return + raise RuntimeError( + "PyMscclpp is disabled because the mscclpp library is not found." + "To silence this warning, specify disable_mscclpp=True explicitly." + ) self.group = group @@ -127,31 +130,29 @@ def __init__( rank = dist.get_rank(group=self.group) world_size = dist.get_world_size(group=self.group) - if world_size == 1: - # No need to initialize mscclpp for single GPU case. - return + assert world_size > 1 + # PyMscclpp is enabled only in cuda graph + super().__init__(world_size=world_size, disabled=True) if world_size not in PyMscclppCommunicator._SUPPORTED_WORLD_SIZES: - logger.warning( + raise ValueError( "PyMscclpp is disabled due to an unsupported world" " size: %d. Supported world sizes: %s. To silence this " "warning, specify disable_mscclpp=True explicitly.", world_size, str(PyMscclppCommunicator._SUPPORTED_WORLD_SIZES), ) - return self.ranks = torch.distributed.get_process_group_ranks(group) self.nranks_per_node = torch.cuda.device_count() # for now mscclpp with stride in the communicator is not tested if not (abs(self.ranks[-1] - self.ranks[0]) == world_size - 1): - logger.warning( + raise ValueError( "PyMscclpp is disabled due to an unsupported group %s." "Please ensure all ranks in the group are consecutive." "To silence this warning, specify disable_mscclpp=True explicitly.", str(self.ranks), ) - return if isinstance(device, int): device = torch.device(f"cuda:{device}") @@ -224,9 +225,6 @@ def __init__( ) self.msg_size2best_config = msg_size2best_config[0] - # PyMscclpp is enabled only in cuda graph - self.disabled = True - def pre_tune_config(self, dtype=torch.bfloat16) -> bool: logger.debug(f"start to pre-tune configs for rank {self.rank}") nthreads_to_try = [256, 512, 1024] @@ -257,46 +255,37 @@ def pre_tune_config(self, dtype=torch.bfloat16) -> bool: f"for msg_size {msg_size}, best_config: {best_config}, best_time: {best_time}us" ) - def should_mscclpp_allreduce( - self, inp: torch.Tensor, op: ReduceOp = ReduceOp.SUM - ) -> bool: - if self.disabled or self._context is None: - return False - if inp.dtype not in PyMscclppCommunicator._SUPPORTED_DTYPE: - return False - if not mscclpp_is_weak_contiguous(inp): - return False - # only support sum op - if op != ReduceOp.SUM: - return False - if inp.numel() * inp.element_size() > self.max_bytes: - return False + def graph_capture_context(self): + return self.change_state(enable=True) + + def should_use_custom_op(self) -> bool: return True - def all_reduce(self, tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM): - if self._IS_CAPTURING: - if torch.cuda.is_current_stream_capturing(): - self.graph_input_set.add((tensor.dtype, tensor.numel())) - msg_size = tensor.numel() * tensor.itemsize + def can_all_reduce(self, input_: torch.Tensor) -> Optional[AllReduceMode]: + if self.disabled or self._context is None: + return None + if input_.dtype not in PyMscclppCommunicator._SUPPORTED_DTYPE: + return None + if not mscclpp_is_weak_contiguous(input_): + return None + if input_.numel() * input_.element_size() > self.max_bytes: + return None + return AllReduceMode.OUTPLACE + + @BaseCommunicator.validate + def all_reduce( + self, + input_: torch.Tensor, + *, + inplace: Optional[bool] = None, + ) -> torch.Tensor: + self.assert_outplace("all_reduce", inplace) + if self.can_all_reduce(input_) is None: + raise ValueError("pymscclpp cannot handle this input") + msg_size = input_.numel() * input_.element_size() index = bisect.bisect_left(self.msg_size_for_finetune, msg_size) msg_size_finetune = self.msg_size_for_finetune[index] nthreads, nblocks = self.msg_size2best_config[msg_size_finetune] - result = torch.empty_like(tensor) - ops.mscclpp_allreduce(self._context, tensor, result, nthreads, nblocks) + result = torch.empty_like(input_) + ops.mscclpp_allreduce(self._context, input_, result, nthreads, nblocks) return result - - @contextmanager - def change_state( - self, - enable: Optional[bool] = None, - ): - if enable is None: - # guess a default value when not specified - enable = self.available - - old_disable = self.disabled - self.disabled = not enable - - yield - - self.disabled = old_disable diff --git a/python/sglang/srt/distributed/device_communicators/pynccl.py b/python/sglang/srt/distributed/device_communicators/pynccl.py index eccbc872e11e..31f87e8a559f 100644 --- a/python/sglang/srt/distributed/device_communicators/pynccl.py +++ b/python/sglang/srt/distributed/device_communicators/pynccl.py @@ -1,7 +1,6 @@ # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/pynccl.py import logging -from contextlib import contextmanager from typing import Optional, Union # ===================== import region ===================== @@ -9,7 +8,12 @@ import torch.distributed as dist from torch.distributed import ProcessGroup, ReduceOp -from sglang.srt.distributed.device_communicators.pynccl_wrapper import ( +from sglang.srt.compilation.piecewise_context_manager import is_in_piecewise_cuda_graph +from sglang.srt.distributed.utils import StatelessProcessGroup +from sglang.srt.utils.common import get_current_device_stream_fast + +from .base import AllReduceMode, BaseCommunicator +from .pynccl_wrapper import ( NCCLLibrary, buffer_type, cudaStream_t, @@ -18,13 +22,12 @@ ncclRedOpTypeEnum, ncclUniqueId, ) -from sglang.srt.distributed.utils import StatelessProcessGroup -from sglang.srt.utils.common import get_current_device_stream_fast logger = logging.getLogger(__name__) -class PyNcclCommunicator: +class PyNcclCommunicator(BaseCommunicator): + name = "pynccl" def __init__( self, @@ -54,26 +57,17 @@ def __init__( else: self.rank = group.rank self.world_size = group.world_size - + # by default it is disabled, e.g. in profiling models and prefill phase. + # to use it, use under `with obj.change_state(enable=True)`, usually + # when we are using CUDA graph. + super().__init__(world_size=self.world_size, disabled=True) self.group = group # if world_size == 1, no need to create communicator - if self.world_size == 1: - self.available = False - self.disabled = True - return - try: - self.nccl = NCCLLibrary(library_path) - except Exception: - # disable because of missing NCCL library - # e.g. in a non-GPU environment - self.available = False - self.disabled = True - return + assert self.world_size > 1 + self.nccl = NCCLLibrary(library_path) self.available = True - self.disabled = False - self.nccl_version = self.nccl.ncclGetRawVersion() if self.rank == 0: logger.info("sglang is using nccl==%s", self.nccl.ncclGetVersion()) @@ -105,7 +99,7 @@ def __init__( # nccl communicator and stream will use this device # `torch.cuda.device` is a context manager that changes the # current cuda device to the specified one - with torch.cuda.device(device): + with torch.cuda.device(device), self.change_state(enable=True): self.comm: ncclComm_t = self.nccl.ncclCommInitRank( self.world_size, self.unique_id, self.rank ) @@ -118,72 +112,63 @@ def __init__( warmup_stream.synchronize() del data - # by default it is disabled, e.g. in profiling models and prefill phase. - # to use it, use under `with obj.change_state(enable=True)`, usually - # when we are using CUDA graph. - self.disabled = True + def should_use_custom_op(self) -> bool: + return True + + @property + def disabled(self) -> bool: + # TODO(dark): this is temporary work around + return self._disabled and not is_in_piecewise_cuda_graph() def _resolve_stream(self) -> torch.cuda.Stream: """Return the current device stream used for NCCL calls.""" return get_current_device_stream_fast() - def all_reduce(self, tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM): - if self.disabled: - return + def graph_capture_context(self): + return self.change_state(enable=True) + + def can_all_reduce(self, input_: torch.Tensor) -> Optional[AllReduceMode]: + # NOTE: inplace all-reduce is not compatible with piecewise CUDA graph + return AllReduceMode.BOTH + + @BaseCommunicator.validate + def all_reduce( + self, + input_: torch.Tensor, + *, + inplace: Optional[bool] = None, + ) -> torch.Tensor: # nccl communicator created on a specific device # will only work on tensors on the same device # otherwise it will cause "illegal memory access" - assert tensor.device == self.device, ( + assert input_.device == self.device, ( f"this nccl communicator is created to work on {self.device}, " - f"but the input tensor is on {tensor.device}" - ) - stream = self._resolve_stream() - self.nccl.ncclAllReduce( - buffer_type(tensor.data_ptr()), - buffer_type(tensor.data_ptr()), - tensor.numel(), - ncclDataTypeEnum.from_torch(tensor.dtype), - ncclRedOpTypeEnum.from_torch(op), - self.comm, - cudaStream_t(stream.cuda_stream), + f"but the input tensor is on {input_.device}" ) - - def outplace_all_reduce( - self, - in_tensor: torch.Tensor, - out_tensor: Optional[torch.Tensor] = None, - op: ReduceOp = ReduceOp.SUM, - ) -> Optional[torch.Tensor]: - if self.disabled: - return None - assert in_tensor.device == self.device, ( - f"this nccl communicator is created to work on {self.device}, " - f"but the input tensor is on {in_tensor.device}" - ) - - if out_tensor is None: - out_tensor = torch.empty_like(in_tensor) + output = input_ + if inplace == False: # default to inplace + output = torch.empty_like(input_) stream = self._resolve_stream() self.nccl.ncclAllReduce( - buffer_type(in_tensor.data_ptr()), # sendbuff - buffer_type(out_tensor.data_ptr()), # recvbuff - DIFFERENT pointer - in_tensor.numel(), - ncclDataTypeEnum.from_torch(in_tensor.dtype), - ncclRedOpTypeEnum.from_torch(op), + buffer_type(input_.data_ptr()), + buffer_type(output.data_ptr()), + input_.numel(), + ncclDataTypeEnum.from_torch(input_.dtype), + ncclRedOpTypeEnum.from_torch(ReduceOp.SUM), self.comm, cudaStream_t(stream.cuda_stream), ) - return out_tensor + return output - def all_gather( + @BaseCommunicator.validate + def all_gather_impl( self, output_tensor: torch.Tensor, input_tensor: torch.Tensor, sizes: Optional[list[int]] = None, - ): - if self.disabled: - return + stream: Optional[torch.cuda.Stream] = None, + ) -> None: # nccl communicator created on a specific device # will only work on tensors on the same device # otherwise it will cause "illegal memory access" @@ -191,7 +176,8 @@ def all_gather( f"this nccl communicator is created to work on {self.device}, " f"but the input tensor is on {input_tensor.device}" ) - stream = self._resolve_stream() + if stream is None: + stream = self._resolve_stream() if sizes is not None: split_offset = 0 @@ -220,42 +206,38 @@ def all_gather( cudaStream_t(stream.cuda_stream), ) + def all_gather_into_tensor( + self, + input_: torch.Tensor, + *, + out: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if out is None: + out = self.allocate_all_gather(input_) + self.all_gather_impl(out, input_) + return out + def cp_all_gather_into_tensor( self, output_tensor: torch.Tensor, input_tensor: torch.Tensor, stream: torch.cuda.Stream, - sizes: Optional[list[int]] = None, - ): + ) -> None: """ Currently, it is mainly used in context parallelism, primarily leveraging pynccl to implement non-blocking allgather communication. """ - # nccl communicator created on a specific device - # will only work on tensors on the same device - # otherwise it will cause "illegal memory access" - assert input_tensor.device == self.device, ( - f"this nccl communicator is created to work on {self.device}, " - f"but the input tensor is on {input_tensor.device}" - ) - self.nccl.ncclAllGather( - buffer_type(input_tensor.data_ptr()), - buffer_type(output_tensor.data_ptr()), - input_tensor.numel(), - ncclDataTypeEnum.from_torch(input_tensor.dtype), - self.comm, - cudaStream_t(stream.cuda_stream), - ) + with self.change_state(enable=True): + return self.all_gather_impl(output_tensor, input_tensor, stream=stream) - def reduce_scatter( + @BaseCommunicator.validate + def reduce_scatter_impl( self, output_tensor: torch.Tensor, input_tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, sizes: Optional[list[int]] = None, - ): - if self.disabled: - return + ) -> None: # nccl communicator created on a specific device # will only work on tensors on the same device # otherwise it will cause "illegal memory access" @@ -294,9 +276,19 @@ def reduce_scatter( cudaStream_t(stream.cuda_stream), ) + def reduce_scatter_tensor( + self, + input_: torch.Tensor, + *, + out: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if out is None: + out = self.allocate_reduce_scatter(input_) + self.reduce_scatter_impl(out, input_) + return out + + @BaseCommunicator.validate def send(self, tensor: torch.Tensor, dst: int): - if self.disabled: - return assert tensor.device == self.device, ( f"this nccl communicator is created to work on {self.device}, " f"but the input tensor is on {tensor.device}" @@ -311,9 +303,8 @@ def send(self, tensor: torch.Tensor, dst: int): cudaStream_t(stream.cuda_stream), ) + @BaseCommunicator.validate def recv(self, tensor: torch.Tensor, src: int): - if self.disabled: - return assert tensor.device == self.device, ( f"this nccl communicator is created to work on {self.device}, " f"but the input tensor is on {tensor.device}" @@ -328,9 +319,8 @@ def recv(self, tensor: torch.Tensor, src: int): cudaStream_t(stream.cuda_stream), ) + @BaseCommunicator.validate def broadcast(self, tensor: torch.Tensor, src: int): - if self.disabled: - return assert tensor.device == self.device, ( f"this nccl communicator is created to work on {self.device}, " f"but the input tensor is on {tensor.device}" @@ -365,19 +355,3 @@ def group_start(self): def group_end(self): self.nccl.ncclGroupEnd() - - @contextmanager - def change_state(self, enable: Optional[bool] = None): - """ - A context manager to change the enabled state of the communicator. - """ - if enable is None: - # guess a default value when not specified - enable = self.available - - old_disable = self.disabled - self.disabled = not enable - try: - yield - finally: - self.disabled = old_disable diff --git a/python/sglang/srt/distributed/device_communicators/pynccl_symm.py b/python/sglang/srt/distributed/device_communicators/pynccl_symm.py new file mode 100644 index 000000000000..0e0a937f9f87 --- /dev/null +++ b/python/sglang/srt/distributed/device_communicators/pynccl_symm.py @@ -0,0 +1,80 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional + +import torch + +from sglang.srt.layers.dp_attention import is_allocation_symmetric + +from .base import AllReduceMode, BaseCommunicator +from .pynccl_allocator import is_symmetric_memory_enabled, use_symmetric_memory + +if TYPE_CHECKING: + from sglang.srt.distributed import GroupCoordinator + from sglang.srt.distributed.device_communicators.pynccl import PyNcclCommunicator + + +class PyNcclSymmMemCommunicator(BaseCommunicator): + name = "pynccl_symm_mem" + + def __init__(self, group_coordinator: GroupCoordinator, pynccl: PyNcclCommunicator): + self.group_coordinator = group_coordinator + self.pynccl = pynccl + super().__init__(self.pynccl.world_size) + + @property + def disabled(self) -> bool: + return self._disabled or not is_symmetric_memory_enabled() + + def should_use_custom_op(self) -> bool: + return True + + def can_all_reduce(self, input_: torch.Tensor) -> Optional[AllReduceMode]: + # always inplace + return AllReduceMode.INPLACE + + @BaseCommunicator.validate + def all_reduce( + self, + input_: torch.Tensor, + *, + inplace: Optional[bool] = None, + ) -> torch.Tensor: + self.assert_inplace("all_reduce", inplace) + with self.pynccl.change_state(enable=True): + self.pynccl.all_reduce(input_, inplace=True) + return input_ + + @BaseCommunicator.validate + def all_gather_into_tensor( + self, + input_: torch.Tensor, + *, + out: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if out is None: + with self._allocation_context(): + out = self.allocate_all_gather(input_) + with self.pynccl.change_state(enable=True): + self.pynccl.all_gather_into_tensor(input_, out=out) + return out + + @BaseCommunicator.validate + def reduce_scatter_tensor( + self, + input_: torch.Tensor, + *, + out: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if out is None: + with self._allocation_context(): + out = self.allocate_reduce_scatter(input_) + with self.pynccl.change_state(enable=True): + self.pynccl.reduce_scatter_tensor(input_, out=out) + return out + + def _allocation_context(self): + return use_symmetric_memory( + group_coordinator=self.group_coordinator, + disabled=not is_allocation_symmetric(), + ) diff --git a/python/sglang/srt/distributed/device_communicators/quick_all_reduce.py b/python/sglang/srt/distributed/device_communicators/quick_all_reduce.py index f9d51246ed8c..de2cf603583d 100644 --- a/python/sglang/srt/distributed/device_communicators/quick_all_reduce.py +++ b/python/sglang/srt/distributed/device_communicators/quick_all_reduce.py @@ -4,20 +4,19 @@ import os from enum import Enum from functools import cache -from typing import Union +from typing import Optional, Union import torch import torch.distributed as dist from torch.distributed import ProcessGroup import sglang.srt.distributed.device_communicators.custom_all_reduce_ops as ops -from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import ( - is_full_nvlink, - is_weak_contiguous, -) from sglang.srt.distributed.parallel_state import in_the_same_node_as from sglang.srt.utils import is_cuda, is_hip +from .base import AllReduceMode, BaseCommunicator +from .custom_all_reduce_utils import is_full_nvlink, is_weak_contiguous + logger = logging.getLogger(__name__) _is_cuda = is_cuda() @@ -49,7 +48,8 @@ class QuickReduceRegime(Enum): MB = 1024 * 1024 -class QuickAllReduce: +class QuickAllReduce(BaseCommunicator): + name = "quick_all_reduce" _SUPPORTED_WORLD_SIZES = [2, 4, 8] _SUPPORTED_DTYPES = [torch.float16, torch.bfloat16] @@ -86,21 +86,13 @@ def __init__( is bind to a unique device, and all communicators in this group are in the same node. """ - self.disabled = True if not qr_rocm_arch_available(): - logger.debug( - "Custom quick allreduce is only supported on ROCm MI300 series." + raise RuntimeError( + "quick all-reduce is only supported on ROCm MI300 series" ) - return if not ops.IS_QUICK_AR_AVAILABLE: - # disable because of missing quick reduce library - # e.g. in a cuda environment - logger.info( - "Custom quick allreduce is disabled because " - "of missing custom quick allreduce library" - ) - return + raise RuntimeError("quick all-reduce library is not available") self.group = group assert ( @@ -109,27 +101,19 @@ def __init__( if not all(in_the_same_node_as(group, source_rank=0)): # No need to initialize custom quick allreduce for # multi-node case. - logger.warning( - "Custom quick allreduce is disabled because this " - "process group spans across nodes." - ) - return + raise RuntimeError("quick all-reduce requires all ranks to be on one node") rank = dist.get_rank(group=self.group) world_size = dist.get_world_size(group=self.group) self.rank = rank self.world_size = world_size - if world_size == 1: - # No need to initialize QuickReduce for single GPU case. - return + assert world_size > 1 + super().__init__(world_size=world_size) if world_size not in QuickAllReduce._SUPPORTED_WORLD_SIZES: - logger.warning( - "Custom quick allreduce is disabled due to an " - "unsupported world size: %d. Supported world sizes: %s.", - world_size, - str(QuickAllReduce._SUPPORTED_WORLD_SIZES), + raise ValueError( + "unsupported quick all-reduce world size: " + f"{world_size}; supported sizes: {QuickAllReduce._SUPPORTED_WORLD_SIZES}" ) - return if isinstance(device, int): device = torch.device(f"cuda:{device}") @@ -158,11 +142,9 @@ def __init__( if _is_cuda or _is_hip: self.fully_connected = is_full_nvlink(physical_device_ids, self.world_size) if self.world_size > 2 and not self.fully_connected: - logger.debug( - "Custom quick allreduce is disabled because it's not supported " - "on more than two PCIe-only GPUs. " + raise RuntimeError( + "quick all-reduce does not support more than two PCIe-only GPUs" ) - return self.init_quick_all_reduce() @@ -175,21 +157,15 @@ def init_quick_all_reduce(self): ) regime_str = os.environ.get("ROCM_QUICK_REDUCE_QUANTIZATION", "NONE") if regime_str not in QuickReduceRegime.__members__: - logger.warning( - "Custom quick allreduce:", - f"Invalid quantization level: {regime_str}. " - "Supported levels: " - f"{list(QuickReduceRegime.__members__.keys())}", + raise ValueError( + "invalid quick all-reduce quantization level: " + f"{regime_str}; supported levels: {list(QuickReduceRegime.__members__.keys())}" ) - return if regime_str == "NONE": - logger.debug( - "Custom quick allreduce is disabled based " - "on env variable " - "ROCM_QUICK_REDUCE_QUANTIZATION='NONE'" + raise RuntimeError( + "quick all-reduce is disabled by ROCM_QUICK_REDUCE_QUANTIZATION='NONE'" ) - return self.qr_quant_level = QuickReduceRegime[regime_str] # TODO: If the dtype is not bfloat16 or then float16, @@ -208,7 +184,6 @@ def init_quick_all_reduce(self): self._ptr = ops.init_custom_qr(self.rank, self.world_size, qr_max_size) self.qr_max_size = qr_max_size if qr_max_size > 0 else ops.qr_max_size() self.create_shared_buffer() - self.disabled = False def create_shared_buffer(self): """ @@ -221,47 +196,54 @@ def create_shared_buffer(self): dist.all_gather_object(handles, handle, group=self.group) ops.qr_open_handles(self._ptr, handles) - def should_quick_allreduce(self, inp: torch.Tensor): - """ - Check if quickreduce is available - """ + def should_use_custom_op(self) -> bool: + return True + + def can_all_reduce(self, input_: torch.Tensor) -> Optional[AllReduceMode]: if self.disabled: - return False - if inp.dtype not in self._SUPPORTED_DTYPES: - return False - inp_size = inp.numel() * inp.element_size() + return None + if input_.dtype not in self._SUPPORTED_DTYPES: + return None + inp_size = input_.numel() * input_.element_size() # custom quick allreduce requires input byte size to be # multiples of 16 if inp_size % 16 != 0: - return False - if not is_weak_contiguous(inp): - return False - dtype = inp.dtype + return None + if not is_weak_contiguous(input_): + return None + dtype = input_.dtype if self.use_fp16_kernels: dtype = torch.float16 - return ( + enabled = ( inp_size <= self.qr_max_size and inp_size >= self._QR_MIN_SIZE[(dtype, self.world_size)][self.qr_quant_level.value] ) - - def quick_all_reduce(self, inp: torch.Tensor, *, out: torch.Tensor = None): - """Performs an out-of-place custom quick all reduce.""" + return AllReduceMode.OUTPLACE if enabled else None + + @BaseCommunicator.validate + def all_reduce( + self, + input_: torch.Tensor, + *, + inplace: Optional[bool] = None, + ) -> torch.Tensor: + """Perform an out-of-place quick all-reduce.""" # quick allreduce doesn't require a separate graph mode, # as QR uses static IPC buffer. - if out is None: - out = torch.empty_like(inp) + self.assert_outplace("all_reduce", inplace) + out = torch.empty_like(input_) ops.qr_all_reduce( - self._ptr, inp, out, self.qr_quant_level.value, self.use_fp16_kernels + self._ptr, input_, out, self.qr_quant_level.value, self.use_fp16_kernels ) return out def close(self): - if not self.disabled and getattr(self, "_ptr", None): + if not getattr(self, "_disabled", True) and getattr(self, "_ptr", None): if ops is not None: ops.qr_destroy(self._ptr) self._ptr = 0 - self.disabled = True + self._disabled = True def __del__(self): self.close() diff --git a/python/sglang/srt/distributed/device_communicators/torch_symm_mem.py b/python/sglang/srt/distributed/device_communicators/torch_symm_mem.py index fb1e1715ae5d..fafc1cca1c0f 100644 --- a/python/sglang/srt/distributed/device_communicators/torch_symm_mem.py +++ b/python/sglang/srt/distributed/device_communicators/torch_symm_mem.py @@ -6,11 +6,10 @@ import torch.distributed as dist from torch.distributed import ProcessGroup -from sglang.srt.distributed.device_communicators.all_reduce_utils import ( - TORCH_SYMM_MEM_ALL_REDUCE_MAX_SIZES, -) from sglang.srt.utils import is_cuda, is_hip +from .base import AllReduceMode, BaseCommunicator + try: import torch.distributed._symmetric_memory as torch_symm_mem @@ -27,7 +26,7 @@ logger = logging.getLogger(__name__) -class TorchSymmMemCommunicator: +class TorchSymmMemCommunicator(BaseCommunicator): """ Thin wrapper around torch-symmetric-memory collectives. @@ -41,6 +40,8 @@ class TorchSymmMemCommunicator: decline to perform symmetric-memory all-reduce. """ + name = "torch_symm_mem" + # Mapping: compute capability major -> supported world sizes for multimem # If the current (cc_major, world_size) is not listed, we fall back # to the two-shot path. @@ -56,10 +57,12 @@ def __init__(self, group: ProcessGroup, device: Union[int, str, torch.device]): device: Target CUDA device (index, 'cuda:X', or torch.device). """ - self.disabled = True - if not torch_symm_mem_available: - return + raise RuntimeError( + "TorchSymmMemCommunicator requires torch symmetric memory support, " + "but it is not available. Ensure you have the correct PyTorch version " + "and that your hardware supports it." + ) if isinstance(device, int): device = torch.device(f"cuda:{device}") @@ -72,22 +75,18 @@ def __init__(self, group: ProcessGroup, device: Union[int, str, torch.device]): self.world_size = dist.get_world_size(self.group) self.device_capability = torch.cuda.get_device_capability(device)[0] if self.device_capability < 9: - logger.warning( + raise RuntimeError( "TorchSymmMemCommunicator: Device capability %s not supported, " - "communicator is not available.", - self.device_capability, + "communicator is not available.".format(self.device_capability) ) - return if ( self.world_size not in TORCH_SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability] ): - logger.warning( + raise RuntimeError( "TorchSymmMemCommunicator: World size %d not supported, " - "communicator is not available.", - self.world_size, + "communicator is not available.".format(self.world_size) ) - return self.max_size = TORCH_SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability][ self.world_size ] @@ -98,16 +97,17 @@ def __init__(self, group: ProcessGroup, device: Union[int, str, torch.device]): ) handle = torch_symm_mem.rendezvous(self.buffer, self.group.group_name) if handle.multicast_ptr == 0: - logger.warning( + self.buffer = None + raise RuntimeError( "TorchSymmMemCommunicator: torch symmetric memory " "multicast operations are not supported." ) - self.buffer = None - self.disabled = True - return - self.disabled = False + super().__init__(world_size=self.world_size) + + def should_use_custom_op(self) -> bool: + return True - def should_torch_symm_mem_allreduce(self, inp: torch.Tensor): + def can_all_reduce(self, input_: torch.Tensor) -> Optional[AllReduceMode]: """ Fast-path eligibility check for a given tensor. @@ -118,46 +118,70 @@ def should_torch_symm_mem_allreduce(self, inp: torch.Tensor): - Payload must be smaller than the symmetric-memory max size. Returns: - True if the symmetric-memory path can handle this tensor. + `AllReduceMode.OUTPLACE` if the symmetric-memory path can handle + this tensor. """ if self.disabled: - return False - if inp.dtype != self.dtype: - return False - inp_size = inp.numel() * inp.element_size() + return None + if input_.dtype != self.dtype: + return None + inp_size = input_.numel() * input_.element_size() # enforce 4-byte alignment if inp_size % 4 != 0: - return False - return inp_size < self.max_size + return None + return AllReduceMode.OUTPLACE if inp_size < self.max_size else None + @BaseCommunicator.validate def all_reduce( - self, inp: torch.Tensor, *, out: Optional[torch.Tensor] = None - ) -> Optional[torch.Tensor]: + self, + input_: torch.Tensor, + *, + inplace: Optional[bool] = None, + ) -> torch.Tensor: """ - Perform an in-place sum all-reduce via torch symmetric memory. + Perform an out-of-place sum all-reduce via torch symmetric memory. Args: - inp: Input tensor on the target CUDA device (bfloat16). - out: Optional output tensor; if omitted, a new tensor is allocated. + input_: Input tensor on the target CUDA device (bfloat16). + inplace: Must be `False` or `None`. Returns: - The reduced tensor (same shape as inp), or None if disabled. + The reduced tensor. Implementation details: - - Stages 'inp' into the symmetric buffer. + - Stages 'input_' into the symmetric buffer. - Selects 'multimem' or 'two_shot' kernel based on topology. - - Writes the result into 'out' and returns it. + - Copies the result into a newly allocated output tensor. """ - if out is None: - out = torch.empty_like(inp) - self.buffer[: inp.numel()].copy_(inp.view(-1)) + self.assert_outplace("all_reduce", inplace) + assert self.buffer is not None, "Symmetric buffer not initialized" + out = torch.empty_like(input_) + self.buffer[: input_.numel()].copy_(input_.view(-1)) if self.world_size in self._WORLD_SIZES_MULTIMEM[self.device_capability]: torch.ops.symm_mem.multimem_all_reduce_( - self.buffer[: inp.numel()], "sum", self.group.group_name + self.buffer[: input_.numel()], "sum", self.group.group_name ) else: torch.ops.symm_mem.two_shot_all_reduce_( - self.buffer[: inp.numel()], "sum", self.group.group_name + self.buffer[: input_.numel()], "sum", self.group.group_name ) - out.copy_(self.buffer[: inp.numel()].view(out.shape)) + out.copy_(self.buffer[: input_.numel()].view(out.shape)) return out + + +MiB = 1024 * 1024 + +TORCH_SYMM_MEM_ALL_REDUCE_MAX_SIZES = { + 9: { + 2: 64 * MiB, # 64 MB + 4: 64 * MiB, # 64 MB + 6: 128 * MiB, # 128 MB + 8: 128 * MiB, # 128 MB + }, + 10: { + 2: 64 * MiB, # 64 MB + 4: 64 * MiB, # 64 MB + 6: 128 * MiB, # 128 MB + 8: 128 * MiB, # 128 MB + }, +} diff --git a/python/sglang/srt/distributed/device_communicators/torch_wrapper.py b/python/sglang/srt/distributed/device_communicators/torch_wrapper.py new file mode 100644 index 000000000000..6ab368c7f727 --- /dev/null +++ b/python/sglang/srt/distributed/device_communicators/torch_wrapper.py @@ -0,0 +1,95 @@ +from typing import List, Optional + +import torch +import torch.distributed as dist + +from .base import AllReduceMode, BaseCommunicator + + +class TorchDefaultCommunicator(BaseCommunicator): + name = "torch_native" + + def __init__( + self, + rank_in_group: int, + ranks: List[int], + device_group: dist.ProcessGroup, + ) -> None: + self.ranks = ranks + self.rank_in_group = rank_in_group + self.device_group = device_group + super().__init__(len(ranks)) + + def change_state(self, enable: bool): + assert enable, "TorchDefaultCommunicator cannot be disabled" + return super().change_state(enable) + + def can_all_reduce(self, input_: torch.Tensor) -> Optional[AllReduceMode]: + return AllReduceMode.INPLACE + + def all_reduce( + self, + input_: torch.Tensor, + *, + inplace: Optional[bool] = None, + ) -> torch.Tensor: + self.assert_inplace("all_reduce", inplace) + dist.all_reduce(input_, group=self.device_group) + return input_ + + def all_gather_into_tensor( + self, + input_: torch.Tensor, + *, + out: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if out is None: + out = self.allocate_all_gather(input_) + dist.all_gather_into_tensor(out, input_, group=self.device_group) + return out + + def all_gather( + self, + input_: torch.Tensor, + *, + out_list: Optional[List[torch.Tensor]] = None, + ) -> List[torch.Tensor]: + if out_list is None: + out_list = [torch.empty_like(input_) for _ in range(self.world_size)] + dist.all_gather(out_list, input_, group=self.device_group) + return out_list + + def reduce_scatter_tensor( + self, + input_: torch.Tensor, + *, + out: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if out is None: + out = self.allocate_reduce_scatter(input_) + dist.reduce_scatter_tensor(out, input_, group=self.device_group) + return out + + def reduce_scatter( + self, + input_list: List[torch.Tensor], + *, + out: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if out is None: + out = torch.empty_like(input_list[self.rank_in_group]) + dist.reduce_scatter(out, input_list, group=self.device_group) + return out + + def gather( + self, + input_: torch.Tensor, + dst: int, + *, + dim: int = 0, + ) -> Optional[torch.Tensor]: + gather_list = None + if self.rank_in_group == dst: + gather_list = [torch.empty_like(input_) for _ in range(self.world_size)] + dist.gather(input_, gather_list, dst=self.ranks[dst], group=self.device_group) + return None if gather_list is None else torch.cat(gather_list, dim=dim) diff --git a/python/sglang/srt/distributed/device_communicators/xpu_communicator.py b/python/sglang/srt/distributed/device_communicators/xpu_communicator.py index 532279f70c35..78c7166f66b7 100644 --- a/python/sglang/srt/distributed/device_communicators/xpu_communicator.py +++ b/python/sglang/srt/distributed/device_communicators/xpu_communicator.py @@ -1,29 +1,46 @@ # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/xpu_communicator.py +from typing import Optional + import torch import torch.distributed as dist from torch.distributed import ProcessGroup from sglang.srt.utils import is_xpu +from .base import AllReduceMode, BaseCommunicator + -class XpuCommunicator: +class XpuCommunicator(BaseCommunicator): + name = "xpu" - def __init__(self, group: ProcessGroup): - if not is_xpu(): - self.disabled = True - return - self.disabled = False + def __init__(self, rank_in_group: int, group: ProcessGroup): + self.rank_in_group = rank_in_group self.group = group - self.world_size = dist.get_world_size(self.group) + super().__init__(dist.get_world_size(group), disabled=not is_xpu()) + + def can_all_reduce(self, input_: torch.Tensor) -> Optional[AllReduceMode]: + return AllReduceMode.INPLACE - def all_reduce(self, x: torch.Tensor) -> torch.Tensor: - dist.all_reduce(x, group=self.group) - return x + @BaseCommunicator.validate + def all_reduce( + self, + input_: torch.Tensor, + *, + inplace: Optional[bool] = None, + ) -> torch.Tensor: + self.assert_inplace("all_reduce", inplace) + dist.all_reduce(input_, group=self.group) + return input_ + @BaseCommunicator.validate def gather( - self, input_: torch.Tensor, rank_in_group: int, dst: int = 0, dim: int = -1 - ): + self, + input_: torch.Tensor, + dst: int, + *, + dim: int = 0, + ) -> Optional[torch.Tensor]: # For xpu path, gather doesn't work properly together with ray # cluster so we use all_gather instead for now. input_size = input_.size() @@ -35,7 +52,7 @@ def gather( torch.distributed.all_gather_into_tensor( output_tensor, input_, group=self.group ) - if rank_in_group == dst: + if self.rank_in_group == dst: # Reshape output_tensor = output_tensor.movedim(0, dim) output_tensor = output_tensor.reshape( diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py index b6f3987b7ed6..fb3087dbce9c 100644 --- a/python/sglang/srt/distributed/parallel_state.py +++ b/python/sglang/srt/distributed/parallel_state.py @@ -22,26 +22,26 @@ steps. """ +from __future__ import annotations + import contextlib import gc import logging import os import pickle -import weakref from collections import namedtuple -from contextlib import contextmanager, nullcontext +from contextlib import contextmanager from dataclasses import dataclass from datetime import timedelta from multiprocessing import shared_memory -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union from unittest.mock import patch import torch import torch.distributed from torch.distributed import Backend, ProcessGroup -from sglang.srt.compilation.compilation_config import register_split_op -from sglang.srt.compilation.piecewise_context_manager import is_in_piecewise_cuda_graph +from sglang.srt.compilation.compile import is_in_piecewise_cuda_graph from sglang.srt.distributed.utils import set_global_tcp_store from sglang.srt.environ import envs from sglang.srt.utils import ( @@ -51,14 +51,15 @@ is_cpu, is_cuda_alike, is_hip, + is_hpu, is_musa, is_npu, is_shm_available, is_xpu, ) -from sglang.srt.utils.custom_op import register_custom_op from sglang.srt.utils.network import get_local_ip_auto +_is_hpu = is_hpu() _is_npu = is_npu() _is_cpu = is_cpu() _is_xpu = is_xpu() @@ -141,54 +142,12 @@ def _get_unique_name(name: str) -> str: return newname -_groups: Dict[str, Callable[[], Optional["GroupCoordinator"]]] = {} - - -def _register_group(group: "GroupCoordinator") -> None: - _groups[group.unique_name] = weakref.ref(group) - - -@register_custom_op(mutates_args=["tensor"]) -@register_split_op() -def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None: - assert group_name in _groups, f"Group {group_name} is not found." - group = _groups[group_name]() - if group is None: - raise ValueError(f"Group {group_name} is destroyed.") - group._all_reduce_in_place(tensor) - - -@register_custom_op(out_shape="tensor") -def outplace_all_reduce( - tensor: torch.Tensor, group_name: str, outplace_all_reduce_method: str -) -> torch.Tensor: - assert group_name in _groups, f"Group {group_name} is not found." - group = _groups[group_name]() - if group is None: - raise ValueError(f"Group {group_name} is destroyed.") - return group._all_reduce_out_place(tensor, outplace_all_reduce_method) - - -@register_custom_op(mutates_args=["output"]) -def reg_all_gather_into_tensor( - output: torch.Tensor, input: torch.Tensor, group_name: str -) -> None: - assert group_name in _groups, f"Group {group_name} is not found." - group = _groups[group_name]() - if group is None: - raise ValueError(f"Group {group_name} is destroyed.") - group._all_gather_into_tensor(output, input) - - -@register_custom_op(mutates_args=["output"]) -def reg_reduce_scatter_tensor( - output: torch.Tensor, input: torch.Tensor, group_name: str -) -> None: - assert group_name in _groups, f"Group {group_name} is not found." - group = _groups[group_name]() - if group is None: - raise ValueError(f"Group {group_name} is destroyed.") - group._reduce_scatter_tensor(output, input) +@contextmanager +def _init_communicator(name: str): + try: + yield + except Exception as e: + logger.warning(f"Failed to initialize {name} with error: {e}.") class GroupCoordinator: @@ -202,36 +161,6 @@ class GroupCoordinator: based on the tensor size and cuda graph mode). """ - # available attributes: - rank: int # global rank - ranks: List[int] # global ranks in the group - world_size: int # size of the group - # difference between `local_rank` and `rank_in_group`: - # if we have a group of size 4 across two nodes: - # Process | Node | Rank | Local Rank | Rank in Group - # 0 | 0 | 0 | 0 | 0 - # 1 | 0 | 1 | 1 | 1 - # 2 | 1 | 2 | 0 | 2 - # 3 | 1 | 3 | 1 | 3 - local_rank: int # local rank used to assign devices - rank_in_group: int # rank inside the group - cpu_group: ProcessGroup # group for CPU communication - device_group: ProcessGroup # group for device communication - use_pynccl: bool # a hint of whether to use PyNccl - use_pymscclpp: bool # a hint of whether to use PyMsccl - use_custom_allreduce: bool # a hint of whether to use CustomAllreduce - use_torch_symm_mem_all_reduce: ( - bool # a hint of whether to use TorchSymmMemAllReduce - ) - use_message_queue_broadcaster: ( - bool # a hint of whether to use message queue broadcaster - ) - # communicators are only created for world size > 1 - pynccl_comm: Optional[Any] # PyNccl communicator - ca_comm: Optional[Any] # Custom allreduce communicator - torch_symm_mem_comm: Optional[Any] # Torch symm mem communicator - mq_broadcaster: Optional[Any] # shared memory broadcaster - def __init__( self, group_ranks: List[List[int]], @@ -247,18 +176,28 @@ def __init__( use_message_queue_broadcaster: bool = False, group_name: Optional[str] = None, gloo_timeout: timedelta = timedelta(seconds=120 * 60), - ): + ) -> None: # Set group info group_name = group_name or "anonymous" self.unique_name = _get_unique_name(group_name) - _register_group(self) - # Set rank info - self.rank = torch.distributed.get_rank() - self.local_rank = local_rank - self.device_group = None - self.cpu_group = None - self.local_size = get_int_env_var("LOCAL_SIZE", 0) + # difference between `local_rank` and `rank_in_group`: + # if we have a group of size 4 across two nodes: + # Process | Node | Rank | Local Rank | Rank in Group + # 0 | 0 | 0 | 0 | 0 + # 1 | 0 | 1 | 1 | 1 + # 2 | 1 | 2 | 0 | 2 + # 3 | 1 | 3 | 1 | 3 + # global rank + self.rank: int = torch.distributed.get_rank() + # local rank used to assign devices + self.local_rank: int = local_rank + # group for device communication + self.device_group: ProcessGroup = None + # group for CPU communication + self.cpu_group: ProcessGroup = None + # for CPU shm communication only + self.local_size: int = get_int_env_var("LOCAL_SIZE", 0) if is_cuda_alike(): device_id = ( @@ -324,119 +263,129 @@ def __init__( self.use_message_queue_broadcaster = use_message_queue_broadcaster # Lazy import to avoid documentation build error - from sglang.srt.distributed.device_communicators.custom_all_reduce import ( - dispatch_custom_allreduce, - ) - from sglang.srt.distributed.device_communicators.pymscclpp import ( - PyMscclppCommunicator, - ) - from sglang.srt.distributed.device_communicators.pynccl import ( - PyNcclCommunicator, - ) + import sglang.srt.distributed.device_communicators as comm from sglang.srt.distributed.device_communicators.pynccl_allocator import ( - is_symmetric_memory_enabled, use_symmetric_memory, ) - from sglang.srt.distributed.device_communicators.torch_symm_mem import ( - TorchSymmMemCommunicator, - ) - from sglang.srt.layers.dp_attention import is_allocation_symmetric - self.is_symmetric_memory_enabled = is_symmetric_memory_enabled + self.torch_comm = comm.TorchDefaultCommunicator( + rank_in_group=self.rank_in_group, + ranks=self.ranks, + device_group=self.device_group, + ) self.use_symmetric_memory = use_symmetric_memory - self.is_allocation_symmetric = is_allocation_symmetric - if is_hip(): - from sglang.srt.distributed.device_communicators.quick_all_reduce import ( - QuickAllReduce, - qr_rocm_arch_available, - ) - - self.pynccl_comm: Optional[PyNcclCommunicator] = None + self.pynccl_comm = None if use_pynccl and self.world_size > 1: - self.pynccl_comm = PyNcclCommunicator( - group=self.cpu_group, - device=self.device, - ) + with _init_communicator("PyNcclCommunicator"): + self.pynccl_comm = comm.PyNcclCommunicator( + group=self.cpu_group, + device=self.device, + ) + + self.pynccl_symm_comm = None + if self.pynccl_comm is not None: + with _init_communicator("PyNcclSymmMemCommunicator"): + self.pynccl_symm_comm = comm.PyNcclSymmMemCommunicator( + group_coordinator=self, + pynccl=self.pynccl_comm, + ) - self.pymscclpp_comm: Optional[PyMscclppCommunicator] = None + self.pymscclpp_comm = None if use_pymscclpp and self.world_size > 1: - self.pymscclpp_comm = PyMscclppCommunicator( - group=self.cpu_group, - device=self.device, - ) + with _init_communicator("PyMscclppCommunicator"): + self.pymscclpp_comm = comm.PyMscclppCommunicator( + group=self.cpu_group, + device=self.device, + ) - self.ca_comm: Optional[Any] = None - self.qr_comm: Optional[QuickAllReduce] = None + self.ca_comm = None + self.qr_comm = None if use_custom_allreduce and self.world_size > 1: # Initialize a custom fast all-reduce implementation. - try: - CAClass = dispatch_custom_allreduce() - self.ca_comm = CAClass( + with _init_communicator("CustomAllReduce"): + self.ca_comm = comm.dispatch_custom_allreduce()( group=self.cpu_group, device=self.device, ) - except Exception as e: - logger.warning( - f"Setup Custom allreduce failed with {e}. To silence this " - "warning, specify --disable-custom-all-reduce explicitly." - ) - if is_hip(): - try: + if is_hip() and comm.qr_rocm_arch_available(): + with _init_communicator("QuickAllReduce"): # Initialize a custom quick all-reduce implementation for AMD # when rocm >= gfx942. Quick reduce is designed as a # complement to custom allreduce. # Based on quickreduce (https://github.com/mk1-project/quickreduce). - if qr_rocm_arch_available(): - self.qr_comm = QuickAllReduce( - group=self.cpu_group, device=self.device - ) - except Exception as e: - logger.warning(f"Failed to initialize QuickAllReduce: {e}") + self.qr_comm = comm.QuickAllReduce( + group=self.cpu_group, device=self.device + ) + elif self.world_size > 1 and is_hip(): logger.info("[AR] All-reduce call path: NCCL (custom AR disabled)") - self.torch_symm_mem_comm: Optional[TorchSymmMemCommunicator] = None + self.torch_symm_mem_comm = None if self.use_torch_symm_mem_all_reduce and self.world_size > 1: - self.torch_symm_mem_comm = TorchSymmMemCommunicator( - group=self.cpu_group, - device=self.device, - ) - - # Create communicator for other hardware backends - from sglang.srt.distributed.device_communicators.hpu_communicator import ( - HpuCommunicator, - ) - from sglang.srt.distributed.device_communicators.npu_communicator import ( - NpuCommunicator, - ) - from sglang.srt.distributed.device_communicators.xpu_communicator import ( - XpuCommunicator, - ) + with _init_communicator("TorchSymmMemCommunicator"): + self.torch_symm_mem_comm = comm.TorchSymmMemCommunicator( + group=self.cpu_group, + device=self.device, + ) - self.hpu_communicator: Optional[HpuCommunicator] = None - if use_hpu_communicator and self.world_size > 1: - self.hpu_communicator = HpuCommunicator(group=self.device_group) + self.hpu_communicator = None + if _is_hpu and use_hpu_communicator and self.world_size > 1: + with _init_communicator("HpuCommunicator"): + self.hpu_communicator = comm.HpuCommunicator(group=self.device_group) + + self.xpu_communicator = None + if _is_xpu and use_xpu_communicator and self.world_size > 1: + with _init_communicator("XpuCommunicator"): + self.xpu_communicator = comm.XpuCommunicator( + rank_in_group=self.rank_in_group, + group=self.device_group, + ) - self.xpu_communicator: Optional[XpuCommunicator] = None - if use_xpu_communicator and self.world_size > 1: - self.xpu_communicator = XpuCommunicator(group=self.device_group) + # Create message queue + self.mq_broadcaster = None + if use_message_queue_broadcaster and self.world_size > 1: + with _init_communicator("MessageQueueBroadcaster"): + self.mq_broadcaster = comm.MessageQueue.create_from_process_group( + self.cpu_group, 1 << 22, 6 + ) - self.npu_communicator: Optional[NpuCommunicator] = None - if use_npu_communicator and self.world_size > 1: - self.npu_communicator = NpuCommunicator(group=self.device_group) + def _filter(*items) -> List: + return [x for x in items if x is not None] - # Create message queue - from sglang.srt.distributed.device_communicators.shm_broadcast import ( - MessageQueue, + self.impl = comm.CommunicatorImpl( + unique_name=self.unique_name, + world_size=self.world_size, + capture_comms=_filter( + self.ca_comm, + self.pynccl_comm, + self.pymscclpp_comm, + ), + all_reduce_comms=_filter( + self.pynccl_symm_comm, + self.xpu_communicator, + self.hpu_communicator, + self.ca_comm, + self.qr_comm, + self.pymscclpp_comm, + self.torch_symm_mem_comm, + self.pynccl_comm, + self.torch_comm, + ), + all_gather_comms=_filter( + self.pynccl_symm_comm, + self.pynccl_comm, + self.hpu_communicator, + self.torch_comm, + ), + reduce_scatter_comms=_filter( + self.pynccl_symm_comm, + self.pynccl_comm, + self.torch_comm, + ), + device_group=self.device_group, ) - self.mq_broadcaster: Optional[MessageQueue] = None - if use_message_queue_broadcaster and self.world_size > 1: - self.mq_broadcaster = MessageQueue.create_from_process_group( - self.cpu_group, 1 << 22, 6 - ) - def __repr__(self): return ( f"ranks={self.ranks} rank={self.rank} local_rank={self.local_rank} use_pynccl={self.use_pynccl} " @@ -490,10 +439,6 @@ def graph_capture( graph_capture_context = GraphCaptureContext(stream) else: stream = graph_capture_context.stream - # We don't need the context of custom quick allreduce because the ipc access - # is already collected in init() and we can capture the quick allreduce directly. - ca_comm = self.ca_comm - maybe_ca_context = nullcontext() if ca_comm is None else ca_comm.capture() # ensure all initialization operations complete before attempting to # capture the graph on another stream @@ -501,47 +446,17 @@ def graph_capture( if curr_stream != stream: stream.wait_stream(curr_stream) - with self.device_module.stream(stream), maybe_ca_context: - # In graph mode, we have to be very careful about the collective - # operations. The current status is: - # allreduce \ Mode | Eager | Graph | - # -------------------------------------------- - # quick allreduce | enabled | enabled | - # custom allreduce | enabled | enabled | - # PyNccl | disabled| enabled | - # PyMscclpp | disabled| enabled | - # TorchSymmMem | disabled| enabled | - # torch.distributed | enabled | disabled| - # - # Note: When custom quick allreduce is enabled, a runtime check - # will be performed. If the tensor size is too small, it will - # automatically fall back to the next available option. - # Note that custom allreduce will have a runtime check, if the - # tensor size is too large, it will fallback to the next - # available option. - # Note that the PyMsccl needs to register the tensor in ahead, - # which will introduce large overhead in the eager case, - # therefore it is only supported in the graph case. - # In summary: We select the appropriate allreduce method for - # each mode based on the algorithm order in the table and - # their usage conditions. - pynccl_comm = self.pynccl_comm - maybe_pynccl_context: Any - if not pynccl_comm: - maybe_pynccl_context = nullcontext() - else: - maybe_pynccl_context = pynccl_comm.change_state(enable=True) + with self.device_module.stream(stream), contextlib.ExitStack() as stack: + for comm_context in self.impl.graph_capture_context(): + stack.enter_context(comm_context) + yield graph_capture_context - pymscclpp_comm = self.pymscclpp_comm - maybe_pymscclpp_context: Any - if not pymscclpp_comm: - maybe_pymscclpp_context = nullcontext() - else: - maybe_pymscclpp_context = pymscclpp_comm.change_state(enable=True) - with maybe_pynccl_context, maybe_pymscclpp_context: - yield graph_capture_context - - def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: + def all_reduce( + self, + input_: torch.Tensor, + *, + inplace: Optional[bool] = None, + ) -> torch.Tensor: """ User-facing all-reduce function before we actually call the all-reduce operation. @@ -567,57 +482,9 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: torch.distributed.all_reduce(input_, group=self.device_group) return input_ - if self.hpu_communicator is not None and not self.hpu_communicator.disabled: - return self.hpu_communicator.all_reduce(input_) - - if self.xpu_communicator is not None and not self.xpu_communicator.disabled: - return self.xpu_communicator.all_reduce(input_) - - if self.npu_communicator is not None and not self.npu_communicator.disabled: - return self.npu_communicator.all_reduce(input_) - - if self.pynccl_comm is not None and self.is_symmetric_memory_enabled(): - with self.pynccl_comm.change_state(enable=True): - self.pynccl_comm.all_reduce(input_) - return input_ - - outplace_all_reduce_method = None - if ( - self.ca_comm is not None - and not self.ca_comm.disabled - and self.ca_comm.should_custom_ar(input_) - ): - outplace_all_reduce_method = "ca" - elif ( - self.qr_comm is not None - and not self.qr_comm.disabled - and self.qr_comm.should_quick_allreduce(input_) - ): - outplace_all_reduce_method = "qr" - elif ( - self.pymscclpp_comm is not None - and not self.pymscclpp_comm.disabled - and self.pymscclpp_comm.should_mscclpp_allreduce(input_) - ): - outplace_all_reduce_method = "pymscclpp" - elif ( - self.torch_symm_mem_comm is not None - and not self.torch_symm_mem_comm.disabled - and self.torch_symm_mem_comm.should_torch_symm_mem_allreduce(input_) - ): - outplace_all_reduce_method = "torch_symm_mem" - elif is_in_piecewise_cuda_graph(): - # For piecewise cuda graph, we use pynccl outplace allreduce - outplace_all_reduce_method = "pynccl" - if outplace_all_reduce_method is not None: - return outplace_all_reduce( - input_, - group_name=self.unique_name, - outplace_all_reduce_method=outplace_all_reduce_method, - ) - else: - inplace_all_reduce(input_, group_name=self.unique_name) - return input_ + if inplace is None and is_in_piecewise_cuda_graph(): + inplace = False + return self.impl.all_reduce(input_, inplace=inplace) def fused_allreduce_rmsnorm( self, @@ -671,74 +538,16 @@ def fused_allreduce_rmsnorm( ) return fused_outputs - def _all_reduce_out_place( - self, input_: torch.Tensor, outplace_all_reduce_method: str - ) -> torch.Tensor: - ca_comm = self.ca_comm - qr_comm = self.qr_comm - pymscclpp_comm = self.pymscclpp_comm - torch_symm_mem_comm = self.torch_symm_mem_comm - pynccl_comm = self.pynccl_comm - assert any([qr_comm, ca_comm, pymscclpp_comm, torch_symm_mem_comm, pynccl_comm]) - if outplace_all_reduce_method == "ca": - assert not ca_comm.disabled - out = ca_comm.custom_all_reduce(input_) - elif outplace_all_reduce_method == "qr": - assert not qr_comm.disabled - out = qr_comm.quick_all_reduce(input_) - elif outplace_all_reduce_method == "torch_symm_mem": - assert not torch_symm_mem_comm.disabled - out = torch_symm_mem_comm.all_reduce(input_) - elif outplace_all_reduce_method == "pymscclpp": - assert not pymscclpp_comm.disabled - out = pymscclpp_comm.all_reduce(input_) - elif outplace_all_reduce_method == "pynccl": - with pynccl_comm.change_state(enable=True): - out = pynccl_comm.outplace_all_reduce(input_) - assert out is not None - return out - - def _all_reduce_in_place(self, input_: torch.Tensor) -> None: - pynccl_comm = self.pynccl_comm - torch_symm_mem_comm = self.torch_symm_mem_comm - if pynccl_comm is not None and not pynccl_comm.disabled: - pynccl_comm.all_reduce(input_) - elif torch_symm_mem_comm is not None and not torch_symm_mem_comm.disabled: - torch_symm_mem_comm.all_reduce(input_) - else: - torch.distributed.all_reduce(input_, group=self.device_group) - - def _reduce_scatter_tensor( - self, - output: torch.Tensor, - input: torch.Tensor, - ) -> torch.Tensor: - pynccl_comm = self.pynccl_comm - if pynccl_comm is not None and ( - not pynccl_comm.disabled or self.is_symmetric_memory_enabled() - ): - with pynccl_comm.change_state(enable=True): - pynccl_comm.reduce_scatter(output, input) - else: - torch.distributed.reduce_scatter_tensor( - output, input, group=self.device_group - ) - return output - def reduce_scatter_tensor(self, output: torch.Tensor, input: torch.Tensor): - if _is_npu: - self._reduce_scatter_tensor(output, input) - else: - reg_reduce_scatter_tensor(output, input, group_name=self.unique_name) + self.impl.reduce_scatter_tensor(input, out=output) def reduce_scatter( self, output: torch.Tensor, input_list: List[torch.Tensor], - ) -> None: + ) -> torch.Tensor: # TODO(ch-wan): support other backends - torch.distributed.reduce_scatter(output, input_list, group=self.device_group) - return output + return self.torch_comm.reduce_scatter(input_list, out=output) def reduce_scatterv( self, @@ -748,6 +557,7 @@ def reduce_scatterv( ) -> torch.Tensor: world_size = self.world_size pynccl_comm = self.pynccl_comm + assert pynccl_comm is not None, "pynccl is required for reduce_scatterv" with pynccl_comm.change_state(enable=True): assert ( @@ -770,26 +580,15 @@ def reduce_scatterv( else: assert output.shape == output_shape - pynccl_comm.reduce_scatter(output, input_, sizes=sizes) + pynccl_comm.reduce_scatter_impl(output, input_, sizes=sizes) return output - def _all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor): - pynccl_comm = self.pynccl_comm - if pynccl_comm is not None and ( - not pynccl_comm.disabled or self.is_symmetric_memory_enabled() - ): - with pynccl_comm.change_state(enable=True): - pynccl_comm.all_gather(output, input) - else: - torch.distributed.all_gather_into_tensor( - output, input, group=self.device_group - ) - - def all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor): - if _is_npu or _is_xpu: - self._all_gather_into_tensor(output, input) - else: - reg_all_gather_into_tensor(output, input, group_name=self.unique_name) + def all_gather_into_tensor( + self, + output: torch.Tensor, + input: torch.Tensor, + ) -> torch.Tensor: + return self.impl.all_gather_into_tensor(input, out=output) def cp_all_gather_into_tensor_async( self, output: torch.Tensor, input: torch.Tensor, stream: torch.cuda.Stream @@ -806,73 +605,40 @@ def cp_all_gather_into_tensor_async( else: pynccl_comm.cp_all_gather_into_tensor(output, input, stream=stream) - def all_gather( + def all_gather_list( self, input_: torch.Tensor, - dim: int = -1, - output_tensor_list: Optional[List[torch.Tensor]] = None, - ) -> torch.Tensor: - world_size = self.world_size - # Bypass the function if we are using only 1 GPU. - if world_size == 1: - if output_tensor_list is not None: - logger.warning( - "Performing in-place all-gather with a group size of 1. " - "This may be unnecessary; consider bypassing it for better efficiency." - ) - output_tensor_list[0].copy_(input_) - return None - else: - return input_ - - if output_tensor_list is not None: - # TODO(ch-wan): support other backends - return torch.distributed.all_gather( - output_tensor_list, input_, group=self.device_group + output_tensor_list: List[torch.Tensor], + ) -> Optional[List[torch.Tensor]]: + if self.world_size == 1: + logger.warning( + "Performing in-place all-gather with a group size of 1. " + "This may be unnecessary; consider bypassing it for better efficiency." ) + output_tensor_list[0].copy_(input_) + return None + return self.torch_comm.all_gather(input_, out_list=output_tensor_list) - assert ( - -input_.dim() <= dim < input_.dim() - ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}" - - # For HPUs, use HPU communicator. - hpu_comm = self.hpu_communicator - if hpu_comm is not None and not hpu_comm.disabled: - return hpu_comm.all_gather(input_, dim) - - # For NPUs, use NPU communicator. - npu_comm = self.npu_communicator - if npu_comm is not None and not npu_comm.disabled: - return npu_comm.all_gather(input_, dim) + def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: + # Bypass the function if we are using only 1 GPU. + world_size = self.world_size + if world_size == 1: + return input_ - if dim < 0: - # Convert negative dim to positive. - dim += input_.dim() + dim = self._normalize_dim(input_, dim) input_size = input_.size() - # NOTE: we have to use concat-style all-gather here, - # stack-style all-gather has compatibility issues with - # torch.compile . see https://github.com/pytorch/pytorch/issues/138795 - output_size = (input_size[0] * world_size,) + input_size[1:] - # Allocate output tensor. - with self.use_symmetric_memory( - self, disabled=not self.is_allocation_symmetric() - ): - output_tensor = torch.empty( - output_size, dtype=input_.dtype, device=input_.device - ) - - # All-gather. if input_.is_cpu: - if is_shm_available(input_.dtype, self.world_size, self.local_size): + output_size = (input_size[0] * world_size,) + input_size[1:] + output_tensor = input_.new_empty(output_size) + if is_shm_available(input_.dtype, world_size, self.local_size): return torch.ops.sgl_kernel.shm_allgather(input_, dim) - else: - torch.distributed.all_gather_into_tensor( - output_tensor, input_, group=self.device_group - ) + torch.distributed.all_gather_into_tensor( + output_tensor, input_, group=self.device_group + ) else: - self.all_gather_into_tensor(output_tensor, input_) + output_tensor = self.impl.all_gather_into_tensor(input_) - # Reshape + # Reshape after all gather into tensor output_tensor = output_tensor.reshape((world_size,) + input_size) output_tensor = output_tensor.movedim(0, dim) output_tensor = output_tensor.reshape( @@ -891,6 +657,7 @@ def all_gatherv( """ world_size = self.world_size pynccl_comm = self.pynccl_comm + assert pynccl_comm is not None, "pynccl is required for all_gatherv" with pynccl_comm.change_state(enable=True): assert ( @@ -929,7 +696,7 @@ def _all_gather_allocate_output( pynccl_comm.group_start() for i, inp in enumerate(input_): - pynccl_comm.all_gather(output_list[i], inp, sizes=size_list[i]) + pynccl_comm.all_gather_impl(output_list[i], inp, sizes=size_list[i]) pynccl_comm.group_end() return output_list @@ -946,28 +713,10 @@ def gather( # Bypass the function if we are using only 1 GPU. if world_size == 1: return input_ - assert ( - -input_.dim() <= dim < input_.dim() - ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}" - if dim < 0: - # Convert negative dim to positive. - dim += input_.dim() + dim = self._normalize_dim(input_, dim) if self.xpu_communicator is not None and not self.xpu_communicator.disabled: - return self.xpu_communicator.gather(input_, self.rank_in_group, dst, dim) - # Allocate output tensor. - if self.rank_in_group == dst: - gather_list = [torch.empty_like(input_) for _ in range(world_size)] - else: - gather_list = None - # Gather. - torch.distributed.gather( - input_, gather_list, dst=self.ranks[dst], group=self.device_group - ) - if self.rank_in_group == dst: - output_tensor = torch.cat(gather_list, dim=dim) - else: - output_tensor = None - return output_tensor + return self.xpu_communicator.gather(input_, dst, dim=dim) + return self.torch_comm.gather(input_, dst=dst, dim=dim) def broadcast(self, input_: torch.Tensor, src: int = 0): """Broadcast the input tensor. @@ -1363,6 +1112,13 @@ def destroy(self): if self.mq_broadcaster is not None: self.mq_broadcaster = None + @staticmethod + def _normalize_dim(input_: torch.Tensor, dim: int) -> int: + assert ( + -input_.dim() <= dim < input_.dim() + ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}" + return dim if dim >= 0 else dim + input_.dim() + _WORLD: Optional[GroupCoordinator] = None @@ -1810,8 +1566,8 @@ def initialize_model_parallel( group_name="pdmux_prefill_tp", ) if _TP.pynccl_comm: - _TP.pynccl_comm.disabled = False - _PDMUX_PREFILL_TP_GROUP.pynccl_comm.disabled = False + _TP.pynccl_comm._disabled = False + _PDMUX_PREFILL_TP_GROUP.pynccl_comm._disabled = False attn_dp_size = attention_data_parallel_size attn_cp_size = attention_context_model_parallel_size diff --git a/python/sglang/srt/layers/dp_attention.py b/python/sglang/srt/layers/dp_attention.py index f5fcd1757423..59a89dc3e45e 100644 --- a/python/sglang/srt/layers/dp_attention.py +++ b/python/sglang/srt/layers/dp_attention.py @@ -467,10 +467,7 @@ def _dp_gather_via_all_reduce( not local_tokens.dtype.is_floating_point and get_tensor_model_parallel_world_size() <= NUM_GPUS_PER_NODE ): - from sglang.srt.distributed.parallel_state import inplace_all_reduce - - inplace_all_reduce(global_tokens, group_name=get_tp_group().unique_name) - + get_tp_group().all_reduce(global_tokens, inplace=True) else: global_tokens[:] = tensor_model_parallel_all_reduce(global_tokens) @@ -581,4 +578,4 @@ def attn_cp_all_gather_into_tensor(output: torch.Tensor, input: torch.Tensor): def attn_tp_all_gather(output_list: List[torch.Tensor], input: torch.Tensor): - return get_attention_tp_group().all_gather(input, output_tensor_list=output_list) + return get_attention_tp_group().all_gather_list(input, output_list) diff --git a/python/sglang/srt/model_executor/mindspore_runner.py b/python/sglang/srt/model_executor/mindspore_runner.py index 4cdcaed505b6..7b7d27560927 100644 --- a/python/sglang/srt/model_executor/mindspore_runner.py +++ b/python/sglang/srt/model_executor/mindspore_runner.py @@ -13,7 +13,7 @@ from mindspore._c_expression import GroupOptions from mindspore.communication import create_group -from sglang.srt.distributed.parallel_state import _groups +from sglang.srt.distributed.device_communicators.impl import _GROUPS logger = logging.getLogger(__name__) @@ -89,7 +89,7 @@ def set_ms_parallel_env(rank, local_rank, world_size, init_method): def reuse_hccl_comm(): - for group_name, group in _groups.items(): + for group_name, group in _GROUPS.items(): # Torch ProcessGroupHccl device_group = group().device_group hccl_comm_handle = device_group._get_backend(torch.device("npu")).get_hccl_comm( diff --git a/sgl-kernel/benchmark/bench_amd_deterministic_allreduce.py b/sgl-kernel/benchmark/bench_amd_deterministic_allreduce.py index 234d09342774..f4254d91968f 100644 --- a/sgl-kernel/benchmark/bench_amd_deterministic_allreduce.py +++ b/sgl-kernel/benchmark/bench_amd_deterministic_allreduce.py @@ -258,7 +258,7 @@ def worker(world_size, rank, port, results_queue): # Measure latency torch.cuda.synchronize() start = time.perf_counter() - result_kernel = custom_ar.custom_all_reduce(inp_kernel) + result_kernel = custom_ar.all_reduce(inp_kernel) torch.cuda.synchronize() end = time.perf_counter() latencies_deterministic_kernel.append(end - start) diff --git a/sgl-kernel/tests/test_amd_deterministic_custom_allreduce.py b/sgl-kernel/tests/test_amd_deterministic_custom_allreduce.py index aa71259251a7..2b2eac009149 100644 --- a/sgl-kernel/tests/test_amd_deterministic_custom_allreduce.py +++ b/sgl-kernel/tests/test_amd_deterministic_custom_allreduce.py @@ -112,7 +112,7 @@ def worker(world_size, rank, port): # Clone the same input inp = base_input.clone() - result = custom_ar.custom_all_reduce(inp) + result = custom_ar.all_reduce(inp) torch.cuda.synchronize() # Store checksum @@ -165,7 +165,7 @@ def worker(world_size, rank, port): # Flatten for all-reduce: (bs * hidden_dim,) batch_flat = batch.view(-1) - result_flat = custom_ar.custom_all_reduce(batch_flat) + result_flat = custom_ar.all_reduce(batch_flat) torch.cuda.synchronize() # Reshape back to (bs, hidden_dim) From 15cbe81c17101ac1ee418a620a56f5438f57a3a8 Mon Sep 17 00:00:00 2001 From: DarkSharpness <2040703891@qq.com> Date: Fri, 3 Apr 2026 12:51:58 +0800 Subject: [PATCH 2/6] fix: fix mindspore compatibility --- python/sglang/srt/distributed/device_communicators/impl.py | 4 +++- python/sglang/srt/distributed/parallel_state.py | 2 ++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/distributed/device_communicators/impl.py b/python/sglang/srt/distributed/device_communicators/impl.py index 730a53ae816f..b57fb225ce43 100644 --- a/python/sglang/srt/distributed/device_communicators/impl.py +++ b/python/sglang/srt/distributed/device_communicators/impl.py @@ -25,6 +25,8 @@ class CommunicatorImpl: # NOTE: never use this, this is only kept for compatibility with # python/sglang/srt/model_executor/mindspore_runner.py device_group: torch.distributed.ProcessGroup + ranks: List[int] + local_rank: int def __post_init__(self): _register_group(self) @@ -139,7 +141,7 @@ def inplace_reduce_scatter( group.reduce_scatter_comms[method].reduce_scatter_tensor(input_, out=out) -@register_custom_op(mutates_args=["out"], out_shape="input_") +@register_custom_op(mutates_args=["out"]) def inplace_all_gather( input_: torch.Tensor, group_name: str, method: int, *, out: torch.Tensor ) -> None: diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py index fb3087dbce9c..44a08c18550f 100644 --- a/python/sglang/srt/distributed/parallel_state.py +++ b/python/sglang/srt/distributed/parallel_state.py @@ -384,6 +384,8 @@ def _filter(*items) -> List: self.torch_comm, ), device_group=self.device_group, + local_rank=self.local_rank, + ranks=self.ranks, ) def __repr__(self): From e5e1841ebf3318400d48f007ec7af60cf295b27b Mon Sep 17 00:00:00 2001 From: DarkSharpness <2040703891@qq.com> Date: Fri, 3 Apr 2026 13:21:35 +0800 Subject: [PATCH 3/6] minor: fix circular import --- .../distributed/device_communicators/{__init__.py => comm.py} | 0 python/sglang/srt/distributed/parallel_state.py | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) rename python/sglang/srt/distributed/device_communicators/{__init__.py => comm.py} (100%) diff --git a/python/sglang/srt/distributed/device_communicators/__init__.py b/python/sglang/srt/distributed/device_communicators/comm.py similarity index 100% rename from python/sglang/srt/distributed/device_communicators/__init__.py rename to python/sglang/srt/distributed/device_communicators/comm.py diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py index 44a08c18550f..8731e0041b35 100644 --- a/python/sglang/srt/distributed/parallel_state.py +++ b/python/sglang/srt/distributed/parallel_state.py @@ -263,7 +263,7 @@ def __init__( self.use_message_queue_broadcaster = use_message_queue_broadcaster # Lazy import to avoid documentation build error - import sglang.srt.distributed.device_communicators as comm + import sglang.srt.distributed.device_communicators.comm as comm from sglang.srt.distributed.device_communicators.pynccl_allocator import ( use_symmetric_memory, ) From 6a7f46b8c656114ae0adb8a36a540c0b5f90b711 Mon Sep 17 00:00:00 2001 From: DarkSharpness <2040703891@qq.com> Date: Fri, 3 Apr 2026 15:03:27 +0800 Subject: [PATCH 4/6] fix: some deficiencies --- .../srt/distributed/device_communicators/impl.py | 4 +++- .../srt/distributed/device_communicators/pynccl.py | 1 - .../distributed/device_communicators/torch_symm_mem.py | 10 ++++++---- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/distributed/device_communicators/impl.py b/python/sglang/srt/distributed/device_communicators/impl.py index b57fb225ce43..621589780c66 100644 --- a/python/sglang/srt/distributed/device_communicators/impl.py +++ b/python/sglang/srt/distributed/device_communicators/impl.py @@ -51,7 +51,8 @@ def all_reduce( if mode is not None and _is_mode_supported(mode, inplace): if not comm.should_use_custom_op(): return comm.all_reduce(input_, inplace=inplace) - if _can_use_inplace(mode): # we prefer in-place if possible + use_inplace = _can_use_inplace(mode) if inplace is None else inplace + if use_inplace: inplace_all_reduce(input_, self.unique_name, i) return input_ else: @@ -160,6 +161,7 @@ def outplace_all_gather( # NOTE(dark): we don't make them class method due to conflict with piecewise cuda graph +# we prefer in-place if possible def _can_use_inplace(mode: AllReduceMode) -> bool: return mode.value != "outplace" diff --git a/python/sglang/srt/distributed/device_communicators/pynccl.py b/python/sglang/srt/distributed/device_communicators/pynccl.py index 31f87e8a559f..8a40102ea3e7 100644 --- a/python/sglang/srt/distributed/device_communicators/pynccl.py +++ b/python/sglang/srt/distributed/device_communicators/pynccl.py @@ -128,7 +128,6 @@ def graph_capture_context(self): return self.change_state(enable=True) def can_all_reduce(self, input_: torch.Tensor) -> Optional[AllReduceMode]: - # NOTE: inplace all-reduce is not compatible with piecewise CUDA graph return AllReduceMode.BOTH @BaseCommunicator.validate diff --git a/python/sglang/srt/distributed/device_communicators/torch_symm_mem.py b/python/sglang/srt/distributed/device_communicators/torch_symm_mem.py index fafc1cca1c0f..9cfb3b3831ff 100644 --- a/python/sglang/srt/distributed/device_communicators/torch_symm_mem.py +++ b/python/sglang/srt/distributed/device_communicators/torch_symm_mem.py @@ -76,16 +76,18 @@ def __init__(self, group: ProcessGroup, device: Union[int, str, torch.device]): self.device_capability = torch.cuda.get_device_capability(device)[0] if self.device_capability < 9: raise RuntimeError( - "TorchSymmMemCommunicator: Device capability %s not supported, " - "communicator is not available.".format(self.device_capability) + "TorchSymmMemCommunicator: " + f"Device capability {self.device_capability} not supported, " + "communicator is not available." ) if ( self.world_size not in TORCH_SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability] ): raise RuntimeError( - "TorchSymmMemCommunicator: World size %d not supported, " - "communicator is not available.".format(self.world_size) + "TorchSymmMemCommunicator: " + f"World size {self.world_size} not supported, " + "communicator is not available." ) self.max_size = TORCH_SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability][ self.world_size From 74a098236a98d82306ac7d52105c75dce6cbeee4 Mon Sep 17 00:00:00 2001 From: DarkSharpness <2040703891@qq.com> Date: Sat, 4 Apr 2026 13:25:49 +0800 Subject: [PATCH 5/6] misc: rename --- .../srt/distributed/device_communicators/base.py | 10 +++++----- .../device_communicators/custom_all_reduce.py | 4 +--- .../device_communicators/custom_all_reduce_aiter.py | 2 +- .../device_communicators/custom_all_reduce_v2.py | 4 +--- .../device_communicators/hpu_communicator.py | 2 +- .../srt/distributed/device_communicators/impl.py | 2 +- .../srt/distributed/device_communicators/pymscclpp.py | 4 +--- .../srt/distributed/device_communicators/pynccl.py | 2 +- .../distributed/device_communicators/pynccl_symm.py | 2 +- .../device_communicators/quick_all_reduce.py | 2 +- .../distributed/device_communicators/torch_symm_mem.py | 2 +- .../distributed/device_communicators/torch_wrapper.py | 2 +- .../device_communicators/xpu_communicator.py | 2 +- 13 files changed, 17 insertions(+), 23 deletions(-) diff --git a/python/sglang/srt/distributed/device_communicators/base.py b/python/sglang/srt/distributed/device_communicators/base.py index 656721b54ecc..69648be80eb4 100644 --- a/python/sglang/srt/distributed/device_communicators/base.py +++ b/python/sglang/srt/distributed/device_communicators/base.py @@ -122,7 +122,7 @@ def should_use_custom_op(self) -> bool: """ return False - def can_all_reduce(self, input_: torch.Tensor) -> Optional[AllReduceMode]: + def get_all_reduce_mode(self, input_: torch.Tensor) -> Optional[AllReduceMode]: """ Report the preferred all-reduce mode for `input_`. @@ -149,15 +149,15 @@ def all_reduce( Run all-reduce on `input_`. Preconditions: - 1. `self.can_all_reduce(input_)` must not return `None`. - 2. `inplace=True` requires `self.can_all_reduce(input_).can_inplace()`. - 3. `inplace=False` requires `self.can_all_reduce(input_).can_outplace()`. + 1. `self.get_all_reduce_mode(input_)` must not return `None`. + 2. `inplace=True` requires `self.get_all_reduce_mode(input_).can_inplace()`. + 3. `inplace=False` requires `self.get_all_reduce_mode(input_).can_outplace()`. 4. `self.disabled` must be `False`. :param input_: Input tensor for the all-reduce. :param inplace: Whether the operation should be in-place. If `None`, the communicator may choose its preferred mode. If specified, it must be - consistent with `can_all_reduce(input_)`. + consistent with `get_all_reduce_mode(input_)`. :return: The reduced tensor. If the operation is in-place, this must be `input_` itself. """ diff --git a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py index 929c4cd4e391..c4ca30b187cc 100644 --- a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py +++ b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py @@ -258,7 +258,7 @@ def register_graph_buffers(self): offsets = [d[1] for d in all_data] # type: ignore ops.register_graph_buffers(self._ptr, handles, offsets) - def can_all_reduce(self, input_: torch.Tensor) -> Optional[AllReduceMode]: + def get_all_reduce_mode(self, input_: torch.Tensor) -> Optional[AllReduceMode]: if self.disabled: return None inp_size = input_.numel() * input_.element_size() @@ -316,8 +316,6 @@ def all_reduce( ) -> torch.Tensor: """The main all-reduce API with CUDA-graph-aware behavior.""" self.assert_outplace("all_reduce", inplace) - if self.can_all_reduce(input_) is None: - raise ValueError("custom_all_reduce cannot handle this input") if self._IS_CAPTURING: if torch.cuda.is_current_stream_capturing(): return self._all_reduce_impl(input_, registered=not self.tms_cudagraph) diff --git a/python/sglang/srt/distributed/device_communicators/custom_all_reduce_aiter.py b/python/sglang/srt/distributed/device_communicators/custom_all_reduce_aiter.py index 4ce4fcb8c06f..1facff855636 100644 --- a/python/sglang/srt/distributed/device_communicators/custom_all_reduce_aiter.py +++ b/python/sglang/srt/distributed/device_communicators/custom_all_reduce_aiter.py @@ -28,7 +28,7 @@ def should_use_custom_op(self) -> bool: def disabled(self) -> bool: return self._disabled or self.comm.disabled - def can_all_reduce(self, input_: torch.Tensor) -> Optional[AllReduceMode]: + def get_all_reduce_mode(self, input_: torch.Tensor) -> Optional[AllReduceMode]: can_use = self.comm.should_custom_ar(input_) return AllReduceMode.OUTPLACE if can_use else None diff --git a/python/sglang/srt/distributed/device_communicators/custom_all_reduce_v2.py b/python/sglang/srt/distributed/device_communicators/custom_all_reduce_v2.py index 32c110a72399..678f9e6337b1 100644 --- a/python/sglang/srt/distributed/device_communicators/custom_all_reduce_v2.py +++ b/python/sglang/srt/distributed/device_communicators/custom_all_reduce_v2.py @@ -108,7 +108,7 @@ def capture(self): self.obj.register_inputs(result) log_info_on_rank0(logger, f"Registering {len(pairs)} cuda graph addresses") - def can_all_reduce(self, input_: torch.Tensor) -> Optional[AllReduceMode]: + def get_all_reduce_mode(self, input_: torch.Tensor) -> Optional[AllReduceMode]: if self.disabled: return None inp_size = input_.numel() * input_.element_size() @@ -127,8 +127,6 @@ def all_reduce( inplace: Optional[bool] = None, ) -> torch.Tensor: self.assert_outplace("all_reduce", inplace) - if self.can_all_reduce(input_) is None: - raise ValueError("custom_all_reduce_v2 cannot handle this input") if is_in_piecewise_cuda_graph(): # disable inplace optimization try: self.obj.set_cuda_graph_capture(False) diff --git a/python/sglang/srt/distributed/device_communicators/hpu_communicator.py b/python/sglang/srt/distributed/device_communicators/hpu_communicator.py index 42a023bfe9df..11ff321b8aeb 100644 --- a/python/sglang/srt/distributed/device_communicators/hpu_communicator.py +++ b/python/sglang/srt/distributed/device_communicators/hpu_communicator.py @@ -21,7 +21,7 @@ def __init__(self, group: ProcessGroup): self.group = group super().__init__(dist.get_world_size(self.group), disabled=not is_hpu()) - def can_all_reduce(self, input_: torch.Tensor) -> Optional[AllReduceMode]: + def get_all_reduce_mode(self, input_: torch.Tensor) -> Optional[AllReduceMode]: return AllReduceMode.INPLACE @BaseCommunicator.validate diff --git a/python/sglang/srt/distributed/device_communicators/impl.py b/python/sglang/srt/distributed/device_communicators/impl.py index 621589780c66..ad377889b48e 100644 --- a/python/sglang/srt/distributed/device_communicators/impl.py +++ b/python/sglang/srt/distributed/device_communicators/impl.py @@ -47,7 +47,7 @@ def all_reduce( for i, comm in enumerate(self.all_reduce_comms): if comm.disabled: continue - mode = comm.can_all_reduce(input_) + mode = comm.get_all_reduce_mode(input_) if mode is not None and _is_mode_supported(mode, inplace): if not comm.should_use_custom_op(): return comm.all_reduce(input_, inplace=inplace) diff --git a/python/sglang/srt/distributed/device_communicators/pymscclpp.py b/python/sglang/srt/distributed/device_communicators/pymscclpp.py index 5acced649896..0625989a3cfe 100644 --- a/python/sglang/srt/distributed/device_communicators/pymscclpp.py +++ b/python/sglang/srt/distributed/device_communicators/pymscclpp.py @@ -261,7 +261,7 @@ def graph_capture_context(self): def should_use_custom_op(self) -> bool: return True - def can_all_reduce(self, input_: torch.Tensor) -> Optional[AllReduceMode]: + def get_all_reduce_mode(self, input_: torch.Tensor) -> Optional[AllReduceMode]: if self.disabled or self._context is None: return None if input_.dtype not in PyMscclppCommunicator._SUPPORTED_DTYPE: @@ -280,8 +280,6 @@ def all_reduce( inplace: Optional[bool] = None, ) -> torch.Tensor: self.assert_outplace("all_reduce", inplace) - if self.can_all_reduce(input_) is None: - raise ValueError("pymscclpp cannot handle this input") msg_size = input_.numel() * input_.element_size() index = bisect.bisect_left(self.msg_size_for_finetune, msg_size) msg_size_finetune = self.msg_size_for_finetune[index] diff --git a/python/sglang/srt/distributed/device_communicators/pynccl.py b/python/sglang/srt/distributed/device_communicators/pynccl.py index 8a40102ea3e7..69b0bb5e0eda 100644 --- a/python/sglang/srt/distributed/device_communicators/pynccl.py +++ b/python/sglang/srt/distributed/device_communicators/pynccl.py @@ -127,7 +127,7 @@ def _resolve_stream(self) -> torch.cuda.Stream: def graph_capture_context(self): return self.change_state(enable=True) - def can_all_reduce(self, input_: torch.Tensor) -> Optional[AllReduceMode]: + def get_all_reduce_mode(self, input_: torch.Tensor) -> Optional[AllReduceMode]: return AllReduceMode.BOTH @BaseCommunicator.validate diff --git a/python/sglang/srt/distributed/device_communicators/pynccl_symm.py b/python/sglang/srt/distributed/device_communicators/pynccl_symm.py index 0e0a937f9f87..6edf0f23b04c 100644 --- a/python/sglang/srt/distributed/device_communicators/pynccl_symm.py +++ b/python/sglang/srt/distributed/device_communicators/pynccl_symm.py @@ -29,7 +29,7 @@ def disabled(self) -> bool: def should_use_custom_op(self) -> bool: return True - def can_all_reduce(self, input_: torch.Tensor) -> Optional[AllReduceMode]: + def get_all_reduce_mode(self, input_: torch.Tensor) -> Optional[AllReduceMode]: # always inplace return AllReduceMode.INPLACE diff --git a/python/sglang/srt/distributed/device_communicators/quick_all_reduce.py b/python/sglang/srt/distributed/device_communicators/quick_all_reduce.py index de2cf603583d..2f0e026d83a1 100644 --- a/python/sglang/srt/distributed/device_communicators/quick_all_reduce.py +++ b/python/sglang/srt/distributed/device_communicators/quick_all_reduce.py @@ -199,7 +199,7 @@ def create_shared_buffer(self): def should_use_custom_op(self) -> bool: return True - def can_all_reduce(self, input_: torch.Tensor) -> Optional[AllReduceMode]: + def get_all_reduce_mode(self, input_: torch.Tensor) -> Optional[AllReduceMode]: if self.disabled: return None if input_.dtype not in self._SUPPORTED_DTYPES: diff --git a/python/sglang/srt/distributed/device_communicators/torch_symm_mem.py b/python/sglang/srt/distributed/device_communicators/torch_symm_mem.py index 9cfb3b3831ff..247609d9f3d4 100644 --- a/python/sglang/srt/distributed/device_communicators/torch_symm_mem.py +++ b/python/sglang/srt/distributed/device_communicators/torch_symm_mem.py @@ -109,7 +109,7 @@ def __init__(self, group: ProcessGroup, device: Union[int, str, torch.device]): def should_use_custom_op(self) -> bool: return True - def can_all_reduce(self, input_: torch.Tensor) -> Optional[AllReduceMode]: + def get_all_reduce_mode(self, input_: torch.Tensor) -> Optional[AllReduceMode]: """ Fast-path eligibility check for a given tensor. diff --git a/python/sglang/srt/distributed/device_communicators/torch_wrapper.py b/python/sglang/srt/distributed/device_communicators/torch_wrapper.py index 6ab368c7f727..af103ea06a46 100644 --- a/python/sglang/srt/distributed/device_communicators/torch_wrapper.py +++ b/python/sglang/srt/distributed/device_communicators/torch_wrapper.py @@ -24,7 +24,7 @@ def change_state(self, enable: bool): assert enable, "TorchDefaultCommunicator cannot be disabled" return super().change_state(enable) - def can_all_reduce(self, input_: torch.Tensor) -> Optional[AllReduceMode]: + def get_all_reduce_mode(self, input_: torch.Tensor) -> Optional[AllReduceMode]: return AllReduceMode.INPLACE def all_reduce( diff --git a/python/sglang/srt/distributed/device_communicators/xpu_communicator.py b/python/sglang/srt/distributed/device_communicators/xpu_communicator.py index 78c7166f66b7..be58d00010d6 100644 --- a/python/sglang/srt/distributed/device_communicators/xpu_communicator.py +++ b/python/sglang/srt/distributed/device_communicators/xpu_communicator.py @@ -19,7 +19,7 @@ def __init__(self, rank_in_group: int, group: ProcessGroup): self.group = group super().__init__(dist.get_world_size(group), disabled=not is_xpu()) - def can_all_reduce(self, input_: torch.Tensor) -> Optional[AllReduceMode]: + def get_all_reduce_mode(self, input_: torch.Tensor) -> Optional[AllReduceMode]: return AllReduceMode.INPLACE @BaseCommunicator.validate From ae34f4bdb69ebc61c9fcc445ed4209d141b529fd Mon Sep 17 00:00:00 2001 From: DarkSharpness <2040703891@qq.com> Date: Sat, 4 Apr 2026 13:32:26 +0800 Subject: [PATCH 6/6] fix: fix custom ar v2 --- python/sglang/jit_kernel/tests/test_custom_all_reduce.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/sglang/jit_kernel/tests/test_custom_all_reduce.py b/python/sglang/jit_kernel/tests/test_custom_all_reduce.py index 7f48fa6e0298..fe80d912642f 100644 --- a/python/sglang/jit_kernel/tests/test_custom_all_reduce.py +++ b/python/sglang/jit_kernel/tests/test_custom_all_reduce.py @@ -213,7 +213,6 @@ def run_eager(x: torch.Tensor) -> torch.Tensor: for _ in range(TEST_LOOP): # NOTE: 15 * 8 < 128, which is the precision limit for bf16 inp = torch.randint(0, 16, (TEST_LAYERS, size), dtype=dtype, device=device) - assert comm.can_all_reduce(inp[0]) out_ref = inp.clone() dist.all_reduce(out_ref, group=nccl_group) out_jit = run_fn(inp)