Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions docs/advanced_features/dp_dpa_smg_guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,31 @@ python -m sglang.launch_server \

Note that MLA models, of course, also support DP. Suppose you want to enable standard DP for MLA models. First, launch each MLA model's replica independently. You may launch these replicas one by one with DPA enabled. After launching each MLA model's replica, launch an SMG and connect all the replicas to the SMG. A detailed explanation of SMG is as follows.

### Fused communication via FlashInfer MixedComm

When DPA is combined with attention tensor parallelism, each layer issues two communication kernels (reducescatter + allgather) after both attention and FFN. SGLang can use MixedComm kernels from [FlashInfer](https://github.com/flashinfer-ai/flashinfer) to fuse these two communication kernels into a single kernel.

The communication handler is constructed once per process during initialization (outside CUDA-graph capture). The performance of some predefined communication sizes is also tested during initialization to enable autotuning.

Requirements for MixedComm to be enabled:
- `--enable-dp-attention` is set and `--dp-size > 1`.
- GPUs are SM90 (e.g., H200) or SM100 (e.g., B200).
- An NVSwitch / NVLink-Switch fabric is present on each node.
- One of the attention TP size and the number of GPUs per node is a multiple of the other.
- The installed version of FlashInfer contains `flashinfer.comm.mixed_comm`.

To enable MixedComm, the environment variable `SGLANG_USE_MIXED_COMM` needs to be set to `1` or `true` before launching the server. It is **off by default**.

```bash
export SGLANG_USE_MIXED_COMM=1
python -m sglang.launch_server \
--model-path deepseek-ai/DeepSeek-V3 \
--tp 8 \
--dp-size 2 \
--enable-dp-attention \
...
```

## Modern Data Parallelism SGLang Model Gateway (SMG)

### Native DP Mode
Expand Down
6 changes: 3 additions & 3 deletions python/sglang/srt/layers/communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1212,14 +1212,14 @@ def _scatter_hidden_states(
get_local_dp_buffer(),
hidden_states,
)
if should_use_dp_reduce_scatterv():
if allow_reduce_scatter and forward_batch.dp_padding_mode.is_max_len():
dp_reduce_scatter_tensor(hidden_states, global_hidden_states)
elif should_use_dp_reduce_scatterv():
get_tp_group().reduce_scatterv(
global_hidden_states,
output=hidden_states,
sizes=get_dp_global_num_tokens(),
)
elif allow_reduce_scatter and forward_batch.dp_padding_mode.is_max_len():
dp_reduce_scatter_tensor(hidden_states, global_hidden_states)
else:
dp_scatter(hidden_states, global_hidden_states, forward_batch)
return hidden_states, residual
Expand Down
104 changes: 104 additions & 0 deletions python/sglang/srt/layers/dp_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import functools
import logging
import os
from contextlib import contextmanager
from enum import IntEnum, auto
from typing import TYPE_CHECKING, List, Optional, Tuple
Expand Down Expand Up @@ -49,6 +50,71 @@
_is_hip = is_hip()
_USE_ROCM700A_WA = _is_hip and get_bool_env_var("SGLANG_USE_ROCM700A")

# When enabled, route DP attention gather/reduce-scatter through flashinfer's
# fused mixed_comm kernels (virtual-memory intra-node + nvshmem inter-node).
_USE_MIXED_COMM = get_bool_env_var("SGLANG_USE_MIXED_COMM")

if _USE_MIXED_COMM:
    # The import only happens on opt-in, so environments that never set the
    # flag do not need a FlashInfer build that ships mixed_comm.
    try:
        from flashinfer.comm.mixed_comm import (
            MixedCommHandler,
            MixedCommOp,
            run_mixed_comm,
        )
    except ImportError as err:
        # Same exception type as before, but the message now names the flag
        # that triggered the import and how to resolve the failure. The
        # original error stays attached as the cause.
        raise ImportError(
            "SGLANG_USE_MIXED_COMM is set, but flashinfer.comm.mixed_comm could "
            "not be imported. Install a FlashInfer version that provides it, or "
            "unset SGLANG_USE_MIXED_COMM."
        ) from err

# Process-wide handler slot; stays None until _init_mixed_comm_handler() fills it.
_MIXED_COMM_HANDLER = None


def _init_mixed_comm_handler(dtype: torch.dtype, device: torch.device):
    """Build the process-wide MixedCommHandler and store it in the module global.

    Call this exactly once per process and never from inside a CUDA graph
    capture: constructing the handler performs collective init (new_group,
    broadcast, barrier), nvshmem setup, and autotune. flashinfer's TP/DP axes
    are mapped onto sglang's attention_tp/attention_dp.

    Args:
        dtype: Forwarded unchanged to ``MixedCommHandler``.
        device: Forwarded to ``MixedCommHandler``. A CUDA device without an
            index is first resolved to the current CUDA device.

    Raises:
        AssertionError: If mixed comm is not enabled, the handler already
            exists, the world size is not a multiple of the per-node size, or
            neither of attention TP size / per-node size divides the other.
    """
    global _MIXED_COMM_HANDLER
    assert _USE_MIXED_COMM
    assert _MIXED_COMM_HANDLER is None, "MixedCommHandler already initialized."

    world_rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()

    # Per-node size: LOCAL_WORLD_SIZE when exported, else the visible GPU count.
    # NOTE(review): the fallback counts visible GPUs, not launched ranks, so a
    # job that uses fewer ranks than visible GPUs will trip the assertion below
    # unless LOCAL_WORLD_SIZE is set -- confirm this is the intended contract.
    visible_gpus = str(torch.cuda.device_count())
    local_size = int(os.getenv("LOCAL_WORLD_SIZE", visible_gpus))
    assert world_size % local_size == 0, (
        f"SGLANG_USE_MIXED_COMM: world_size={world_size} is not a multiple of "
        f"LOCAL_WORLD_SIZE={local_size}."
    )

    # Split the global rank into (node index, index within the node).
    inter_rank, local_rank = divmod(world_rank, local_size)
    inter_size = world_size // local_size

    # The handler is given a concrete device index, so resolve a bare "cuda".
    lacks_index = device.type == "cuda" and device.index is None
    if lacks_index:
        device = torch.device("cuda", torch.cuda.current_device())

    attn_tp_size = get_attention_tp_size()
    assert attn_tp_size % local_size == 0 or local_size % attn_tp_size == 0, (
        f"SGLANG_USE_MIXED_COMM: the larger one of attn_tp_size={attn_tp_size} "
        f"and LOCAL_WORLD_SIZE={local_size} must be a multiple of the smaller one."
    )

    # An attention-TP group fills as much of a node as it can; whatever is
    # left of the group spans nodes, and the remaining ranks form the DP axes.
    local_tp_size = min(attn_tp_size, local_size)
    inter_tp_size = attn_tp_size // local_tp_size

    handler_kwargs = {
        "world_rank": world_rank,
        "world_size": world_size,
        "local_rank": local_rank,
        "local_size": local_size,
        "inter_rank": inter_rank,
        "inter_size": inter_size,
        "local_tp_size": local_tp_size,
        "local_dp_size": local_size // local_tp_size,
        "inter_tp_size": inter_tp_size,
        "inter_dp_size": inter_size // inter_tp_size,
        "dtype": dtype,
        "device": device,
    }
    _MIXED_COMM_HANDLER = MixedCommHandler(**handler_kwargs)


class DpPaddingMode(IntEnum):

Expand Down Expand Up @@ -310,6 +376,11 @@ def initialize_dp_attention(
device=torch.device(server_args.device),
)

# Construct the MixedCommHandler eagerly so its collective/nvshmem init and
# autotune run outside any CUDA graph capture that happens later.
if _USE_MIXED_COMM:
_init_mixed_comm_handler(model_config.dtype, torch.device(server_args.device))


def is_dp_attention_enabled() -> bool:
    """Return the current value of the module-level DP-attention flag."""
    return _ENABLE_DP_ATTENTION_FLAG
Expand Down Expand Up @@ -484,6 +555,25 @@ def _dp_gather_via_all_gather(
forward_batch: ForwardBatch,
is_partial: bool,
):
if _USE_MIXED_COMM:
if get_attention_tp_size() == 1:
run_mixed_comm(
MixedCommOp.ALLGATHER,
_MIXED_COMM_HANDLER,
local_tokens,
global_tokens,
)
else:
if not is_partial and get_attention_tp_rank() != 0:
local_tokens.fill_(0)
run_mixed_comm(
MixedCommOp.ALLREDUCE_ALLGATHER,
_MIXED_COMM_HANDLER,
local_tokens,
global_tokens,
)
return

if get_attention_tp_size() == 1:
get_tp_group().all_gather_into_tensor(global_tokens, local_tokens)
return
Expand Down Expand Up @@ -553,6 +643,20 @@ def dp_scatter(


def dp_reduce_scatter_tensor(output: torch.Tensor, input: torch.Tensor):
if _USE_MIXED_COMM:
if get_tensor_model_parallel_world_size() == get_attention_dp_size():
run_mixed_comm(
MixedCommOp.REDUCESCATTER, _MIXED_COMM_HANDLER, input, output
)
else:
run_mixed_comm(
MixedCommOp.REDUCESCATTER_ALLREDUCE,
_MIXED_COMM_HANDLER,
input,
output,
)
return

if get_tensor_model_parallel_world_size() == get_attention_dp_size():
get_tp_group().reduce_scatter_tensor(output, input)
else:
Expand Down
2 changes: 2 additions & 0 deletions python/sglang/srt/layers/moe/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
from sglang.srt.layers.dp_attention import (
get_attention_dp_size,
get_attention_tp_size,
is_dp_attention_enabled,
)

Expand Down Expand Up @@ -346,6 +347,7 @@ def should_use_dp_reduce_scatterv():
and get_moe_a2a_backend().is_none()
and is_dp_attention_enabled()
and get_attention_dp_size() > 1
and get_attention_tp_size() == 1
and get_moe_expert_parallel_world_size() == get_attention_dp_size()
)

Expand Down
46 changes: 46 additions & 0 deletions test/registered/distributed/test_dp_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,52 @@ def tearDownClass(cls):
kill_process_tree(cls.process.pid)


class TestDPAttentionDP2TP2MixedComm(
    CustomTestCase,
    GSM8KMixin,
    JSONConstrainedMixin,
    EBNFConstrainedMixin,
    RegexConstrainedMixin,
):
    """Run the mixin test suites against a tp=2 / dp=2 DP-attention server.

    The server is launched with ``SGLANG_USE_MIXED_COMM=1`` and torch.compile
    enabled. The whole class is skipped when ``flashinfer.comm.mixed_comm``
    cannot be imported.
    """

    gsm8k_accuracy_thres = 0.6

    @classmethod
    def setUpClass(cls):
        """Launch the server once for the class, with the env override active."""
        try:
            import flashinfer.comm.mixed_comm  # noqa: F401
        except ImportError as e:
            raise unittest.SkipTest(f"flashinfer.comm.mixed_comm unavailable: {e}")

        cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls._env_override = envs.SGLANG_DISABLE_CONSECUTIVE_PREFILL_OVERLAP.override(
            True
        )
        cls._env_override.__enter__()
        try:
            cls.process = popen_launch_server(
                cls.model,
                cls.base_url,
                timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
                other_args=[
                    "--trust-remote-code",
                    "--tp",
                    "2",
                    "--enable-dp-attention",
                    "--dp",
                    "2",
                    "--enable-torch-compile",
                    "--torch-compile-max-bs",
                    "2",
                ],
                env={"SGLANG_USE_MIXED_COMM": "1"},
            )
        except BaseException:
            # unittest does not run tearDownClass when setUpClass raises, so
            # the override must be undone here or it leaks into later tests.
            cls._env_override.__exit__(None, None, None)
            raise

    @classmethod
    def tearDownClass(cls):
        """Stop the server and always restore the overridden env setting."""
        try:
            kill_process_tree(cls.process.pid)
        finally:
            # Restore the override even if killing the process tree fails.
            cls._env_override.__exit__(None, None, None)


class TestDPRetract(
CustomTestCase,
JSONConstrainedMixin,
Expand Down
50 changes: 50 additions & 0 deletions test/registered/distributed/test_dp_attention_large.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,56 @@ def test_gsm8k(self):
self.assertGreater(metrics["score"], 0.8)


@unittest.skipIf(
    is_in_amd_ci(),
    "MixedComm path is CUDA-only (flashinfer.comm.mixed_comm)",
)
class TestDPAttentionDP2TP4MixedComm(
    CustomTestCase,
    JSONConstrainedMixin,
    EBNFConstrainedMixin,
    RegexConstrainedMixin,
):
    """Constrained-decoding mixins plus GSM8K on a tp=4 / dp=2 server.

    The server is launched with ``SGLANG_USE_MIXED_COMM=1``. The class is
    skipped on AMD CI and when ``flashinfer.comm.mixed_comm`` is not importable.
    """

    @classmethod
    def setUpClass(cls):
        """Start one shared server process for every test in the class."""
        try:
            import flashinfer.comm.mixed_comm  # noqa: F401
        except ImportError as e:
            raise unittest.SkipTest(f"flashinfer.comm.mixed_comm unavailable: {e}")

        launch_flags = [
            "--trust-remote-code",
            "--tp=4",
            "--enable-dp-attention",
            "--dp=2",
        ]
        server_env = {"SGLANG_USE_MIXED_COMM": "1"}

        cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=launch_flags,
            env=server_env,
        )

    @classmethod
    def tearDownClass(cls):
        """Terminate the server process tree started in setUpClass."""
        kill_process_tree(cls.process.pid)

    def test_gsm8k(self):
        """Run the gsm8k eval and require a score above 0.8."""
        eval_settings = {
            "base_url": self.base_url,
            "model": self.model,
            "eval_name": "gsm8k",
            "num_examples": None,
            "num_threads": 1024,
        }

        metrics = run_eval(SimpleNamespace(**eval_settings))
        print(f"{metrics=}")
        self.assertGreater(metrics["score"], 0.8)


@unittest.skipIf(
is_in_amd_ci(),
"DeepSeek MTP forward_mla NameError on AMD + needs 8 GPUs",
Expand Down
Loading