diff --git a/docs/advanced_features/dp_dpa_smg_guide.md b/docs/advanced_features/dp_dpa_smg_guide.md index 9ec5df64856e..9c5665e944c8 100644 --- a/docs/advanced_features/dp_dpa_smg_guide.md +++ b/docs/advanced_features/dp_dpa_smg_guide.md @@ -117,6 +117,31 @@ python -m sglang.launch_server \ Note that MLA models, of course, also support DP. Suppose you want to enable standard DP for MLA models. First, launch each MLA model's replica independently. You may launch these replicas one by one with DPA enabled. After launching each MLA model's replica, launch an SMG and connect all the replicas to the SMG. A detailed explanation of SMG is as follows. +### Fused communication via FlashInfer MixedComm + +When DPA is combined with attention tensor parallelism, each layer issues two communication kernels (reducescatter + allgather) both after attention and FFN. SGLang can use MixedComm kernels from [FlashInfer](https://github.com/flashinfer-ai/flashinfer) to fuse two communication kernels into a single kernel. + +The communication handler is constructed once per process during initialization (outside CUDA-graph capture). The performance of some predefined communication sizes is also tested during initialization to enable autotuning. + +Requirements for MixedComm to be enabled: +- `--enable-dp-attention` is set and `--dp-size > 1`. +- GPUs are SM90 (e.g., H200) or SM100 (e.g., B200)**. +- An NVSwitch / NVLink-Switch fabric is present on each node. +- Attention TP size and the number of GPUs per node are divisors of one another. +- The installed version of FlashInfer contains `flashinfer.comm.mixed_comm`. + +To enable MixedComm, the environment variable `SGLANG_USE_MIXED_COMM` needs to be set to `1` or `true` before launching the server. It is **off by default**. + +```bash +export SGLANG_USE_MIXED_COMM=1 +python -m sglang.launch_server \ + --model-path deepseek-ai/DeepSeek-V3 \ + --tp 8 \ + --dp-size 2 \ + --enable-dp-attention \ + ... +``` + ## Modern Data Parallelism SGLang Model Gateway (SMG) ### Native DP Mode diff --git a/python/sglang/srt/layers/communicator.py b/python/sglang/srt/layers/communicator.py index 853cf3ad50a8..3ab5f53b377a 100644 --- a/python/sglang/srt/layers/communicator.py +++ b/python/sglang/srt/layers/communicator.py @@ -1212,14 +1212,14 @@ def _scatter_hidden_states( get_local_dp_buffer(), hidden_states, ) - if should_use_dp_reduce_scatterv(): + if allow_reduce_scatter and forward_batch.dp_padding_mode.is_max_len(): + dp_reduce_scatter_tensor(hidden_states, global_hidden_states) + elif should_use_dp_reduce_scatterv(): get_tp_group().reduce_scatterv( global_hidden_states, output=hidden_states, sizes=get_dp_global_num_tokens(), ) - elif allow_reduce_scatter and forward_batch.dp_padding_mode.is_max_len(): - dp_reduce_scatter_tensor(hidden_states, global_hidden_states) else: dp_scatter(hidden_states, global_hidden_states, forward_batch) return hidden_states, residual diff --git a/python/sglang/srt/layers/dp_attention.py b/python/sglang/srt/layers/dp_attention.py index 09d307aeb73f..097efcf3a3ad 100644 --- a/python/sglang/srt/layers/dp_attention.py +++ b/python/sglang/srt/layers/dp_attention.py @@ -2,6 +2,7 @@ import functools import logging +import os from contextlib import contextmanager from enum import IntEnum, auto from typing import TYPE_CHECKING, List, Optional, Tuple @@ -49,6 +50,71 @@ _is_hip = is_hip() _USE_ROCM700A_WA = _is_hip and get_bool_env_var("SGLANG_USE_ROCM700A") +# When enabled, route DP attention gather/reduce-scatter through flashinfer's +# fused mixed_comm kernels (virtual-memory intra-node + nvshmem inter-node). +_USE_MIXED_COMM = get_bool_env_var("SGLANG_USE_MIXED_COMM") + +if _USE_MIXED_COMM: + from flashinfer.comm.mixed_comm import ( + MixedCommHandler, + MixedCommOp, + run_mixed_comm, + ) + +_MIXED_COMM_HANDLER = None + + +def _init_mixed_comm_handler(dtype: torch.dtype, device: torch.device): + """Construct the process-wide MixedCommHandler. + + Must be called once per process, at a point outside any CUDA graph capture, + because the constructor performs collective init (new_group, broadcast, + barrier), nvshmem setup, and autotune. Topology maps flashinfer TP/DP to + sglang attention_tp/attention_dp. + """ + global _MIXED_COMM_HANDLER + assert _USE_MIXED_COMM + assert _MIXED_COMM_HANDLER is None, "MixedCommHandler already initialized." + + world_rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + local_size = int(os.environ.get("LOCAL_WORLD_SIZE", str(torch.cuda.device_count()))) + assert world_size % local_size == 0, ( + f"SGLANG_USE_MIXED_COMM: world_size={world_size} is not a multiple of " + f"LOCAL_WORLD_SIZE={local_size}." + ) + local_rank = world_rank % local_size + inter_size = world_size // local_size + inter_rank = world_rank // local_size + + if device.type == "cuda" and device.index is None: + device = torch.device("cuda", torch.cuda.current_device()) + + attn_tp_size = get_attention_tp_size() + assert attn_tp_size % local_size == 0 or local_size % attn_tp_size == 0, ( + f"SGLANG_USE_MIXED_COMM: the larger one of attn_tp_size={attn_tp_size} " + f"and LOCAL_WORLD_SIZE={local_size} must be a multiple of the smaller one." + ) + local_tp_size = min(attn_tp_size, local_size) + local_dp_size = local_size // local_tp_size + inter_tp_size = attn_tp_size // local_tp_size + inter_dp_size = inter_size // inter_tp_size + + _MIXED_COMM_HANDLER = MixedCommHandler( + world_rank=world_rank, + world_size=world_size, + local_rank=local_rank, + local_size=local_size, + inter_rank=inter_rank, + inter_size=inter_size, + local_tp_size=local_tp_size, + local_dp_size=local_dp_size, + inter_tp_size=inter_tp_size, + inter_dp_size=inter_dp_size, + dtype=dtype, + device=device, + ) + class DpPaddingMode(IntEnum): @@ -310,6 +376,11 @@ def initialize_dp_attention( device=torch.device(server_args.device), ) + # Construct the MixedCommHandler eagerly so its collective/nvshmem init and + # autotune run outside any CUDA graph capture that happens later. + if _USE_MIXED_COMM: + _init_mixed_comm_handler(model_config.dtype, torch.device(server_args.device)) + def is_dp_attention_enabled() -> bool: return _ENABLE_DP_ATTENTION_FLAG @@ -484,6 +555,25 @@ def _dp_gather_via_all_gather( forward_batch: ForwardBatch, is_partial: bool, ): + if _USE_MIXED_COMM: + if get_attention_tp_size() == 1: + run_mixed_comm( + MixedCommOp.ALLGATHER, + _MIXED_COMM_HANDLER, + local_tokens, + global_tokens, + ) + else: + if not is_partial and get_attention_tp_rank() != 0: + local_tokens.fill_(0) + run_mixed_comm( + MixedCommOp.ALLREDUCE_ALLGATHER, + _MIXED_COMM_HANDLER, + local_tokens, + global_tokens, + ) + return + if get_attention_tp_size() == 1: get_tp_group().all_gather_into_tensor(global_tokens, local_tokens) return @@ -553,6 +643,20 @@ def dp_scatter( def dp_reduce_scatter_tensor(output: torch.Tensor, input: torch.Tensor): + if _USE_MIXED_COMM: + if get_tensor_model_parallel_world_size() == get_attention_dp_size(): + run_mixed_comm( + MixedCommOp.REDUCESCATTER, _MIXED_COMM_HANDLER, input, output + ) + else: + run_mixed_comm( + MixedCommOp.REDUCESCATTER_ALLREDUCE, + _MIXED_COMM_HANDLER, + input, + output, + ) + return + if get_tensor_model_parallel_world_size() == get_attention_dp_size(): get_tp_group().reduce_scatter_tensor(output, input) else: diff --git a/python/sglang/srt/layers/moe/utils.py b/python/sglang/srt/layers/moe/utils.py index e05167da972a..c3cf8be3d082 100644 --- a/python/sglang/srt/layers/moe/utils.py +++ b/python/sglang/srt/layers/moe/utils.py @@ -11,6 +11,7 @@ from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size from sglang.srt.layers.dp_attention import ( get_attention_dp_size, + get_attention_tp_size, is_dp_attention_enabled, ) @@ -346,6 +347,7 @@ def should_use_dp_reduce_scatterv(): and get_moe_a2a_backend().is_none() and is_dp_attention_enabled() and get_attention_dp_size() > 1 + and get_attention_tp_size() == 1 and get_moe_expert_parallel_world_size() == get_attention_dp_size() ) diff --git a/test/registered/distributed/test_dp_attention.py b/test/registered/distributed/test_dp_attention.py index 53d86336a2bd..01a3371c464a 100644 --- a/test/registered/distributed/test_dp_attention.py +++ b/test/registered/distributed/test_dp_attention.py @@ -100,6 +100,52 @@ def tearDownClass(cls): kill_process_tree(cls.process.pid) +class TestDPAttentionDP2TP2MixedComm( + CustomTestCase, + GSM8KMixin, + JSONConstrainedMixin, + EBNFConstrainedMixin, + RegexConstrainedMixin, +): + gsm8k_accuracy_thres = 0.6 + + @classmethod + def setUpClass(cls): + try: + import flashinfer.comm.mixed_comm # noqa: F401 + except ImportError as e: + raise unittest.SkipTest(f"flashinfer.comm.mixed_comm unavailable: {e}") + + cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA + cls.base_url = DEFAULT_URL_FOR_TEST + cls._env_override = envs.SGLANG_DISABLE_CONSECUTIVE_PREFILL_OVERLAP.override( + True + ) + cls._env_override.__enter__() + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--trust-remote-code", + "--tp", + "2", + "--enable-dp-attention", + "--dp", + "2", + "--enable-torch-compile", + "--torch-compile-max-bs", + "2", + ], + env={"SGLANG_USE_MIXED_COMM": "1"}, + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + cls._env_override.__exit__(None, None, None) + + class TestDPRetract( CustomTestCase, JSONConstrainedMixin, diff --git a/test/registered/distributed/test_dp_attention_large.py b/test/registered/distributed/test_dp_attention_large.py index 0a21c5ac752c..dc13dc471266 100644 --- a/test/registered/distributed/test_dp_attention_large.py +++ b/test/registered/distributed/test_dp_attention_large.py @@ -70,6 +70,56 @@ def test_gsm8k(self): self.assertGreater(metrics["score"], 0.8) +@unittest.skipIf( + is_in_amd_ci(), + "MixedComm path is CUDA-only (flashinfer.comm.mixed_comm)", +) +class TestDPAttentionDP2TP4MixedComm( + CustomTestCase, + JSONConstrainedMixin, + EBNFConstrainedMixin, + RegexConstrainedMixin, +): + @classmethod + def setUpClass(cls): + try: + import flashinfer.comm.mixed_comm # noqa: F401 + except ImportError as e: + raise unittest.SkipTest(f"flashinfer.comm.mixed_comm unavailable: {e}") + + cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--trust-remote-code", + "--tp=4", + "--enable-dp-attention", + "--dp=2", + ], + env={"SGLANG_USE_MIXED_COMM": "1"}, + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_gsm8k(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + num_examples=None, + num_threads=1024, + ) + + metrics = run_eval(args) + print(f"{metrics=}") + self.assertGreater(metrics["score"], 0.8) + + @unittest.skipIf( is_in_amd_ci(), "DeepSeek MTP forward_mla NameError on AMD + needs 8 GPUs",