Merged

Changes from all commits (54 commits)
52b050d
implement gather before attn
ch-wan May 18, 2025
abe5c79
Merge branch 'main' into cheng/gather_before_attn
zhyncs May 18, 2025
d9a745a
wip: --enable-ep-moe drops accuracy
ch-wan May 18, 2025
bdbd802
Merge branch 'main' into cheng/fix-6297
ch-wan May 18, 2025
2f34132
Merge branch 'main' into cheng/gather_before_attn
zhyncs May 18, 2025
824254c
Merge branch 'main' into cheng/gather_before_attn
ch-wan May 19, 2025
b324ca0
Merge branch 'main' into cheng/fix-6297
ch-wan May 19, 2025
35267c3
Merge commit 'eb8f02dd87acd8689c41d15a7c0f11f5eff914d0' into cheng/ga…
ch-wan May 27, 2025
7845798
update communicator
ch-wan May 27, 2025
95df400
Merge branch 'cheng/gather_before_attn' into cheng/fix-6297
ch-wan May 27, 2025
f6fc191
Merge commit 'b18416fbf869fd2d150937d3efcf9e75ee3fb278' into cheng/ga…
ch-wan May 27, 2025
eed7ce3
fmt
ch-wan May 27, 2025
cdd6e9b
Merge branch 'cheng/gather_before_attn' into cheng/fix-6297
ch-wan May 27, 2025
81c53e0
Merge branch 'main' into cheng/gather_before_attn
zhyncs May 27, 2025
75773aa
fix
ch-wan May 27, 2025
8b138a1
fix
ch-wan May 27, 2025
fcaaeaf
fmt
ch-wan May 27, 2025
8c09414
Merge branch 'main' into cheng/gather_before_attn
ch-wan May 28, 2025
c15836b
Merge commit '8c09414e29636629f0c0544591f12933073ce5c5' into cheng/fi…
ch-wan May 28, 2025
cd6d94c
Merge commit 'b1c8d4e9f31953560f2db45a3b6e68099ef00c13' into cheng/fi…
ch-wan May 28, 2025
1870255
format
ch-wan May 28, 2025
f8d152c
fix
ch-wan May 28, 2025
e170745
Update communicator.py
ch-wan May 28, 2025
c80f333
Update communicator.py
ch-wan May 28, 2025
0c92871
fix
ch-wan May 28, 2025
a269c5b
Merge branch 'main' into cheng/gather_before_attn
ch-wan Jun 3, 2025
7a3131e
format
ch-wan Jun 3, 2025
fe98477
Merge commit '7a3131e6868ec903d03156d1765c6722b9fbcaad' into cheng/fi…
ch-wan Jun 3, 2025
260bbef
multiple fixes
ch-wan Jun 4, 2025
8cc1fea
format
ch-wan Jun 4, 2025
e78fa2c
Merge commit '4f723edd3baf3823eddfb9d6426548daba17c687' into cheng/fi…
ch-wan Jun 15, 2025
2bc64aa
Merge branch 'main' into cheng/gather_before_attn
ch-wan Jun 15, 2025
e596568
Merge branch 'main' into cheng/gather_before_attn
zhyncs Jun 15, 2025
bfdec93
Merge branch 'main' into cheng/gather_before_attn
ch-wan Jun 15, 2025
80d2f03
Merge remote-tracking branch 'origin/cheng/gather_before_attn' into c…
ch-wan Jun 15, 2025
dd99fb9
Merge commit '80d2f03ecde5992184f6128538126691ee89f04b' into cheng/fi…
ch-wan Jun 16, 2025
ca887b8
refactor
ch-wan Jun 16, 2025
ead7d9f
revert dp attention test
ch-wan Jun 16, 2025
0397727
add a hybrid test
ch-wan Jun 16, 2025
33fb878
format
ch-wan Jun 16, 2025
8c1789d
Merge commit '0ae1e9a75573b5159afb9149b8c80ae76f239ff7' into cheng/fi…
ch-wan Jun 16, 2025
e061123
Update test_hybrid_dp_ep_tp.py
ch-wan Jun 16, 2025
d887f77
update tests
ch-wan Jun 16, 2025
713f87b
update test
ch-wan Jun 18, 2025
d39549d
fix
ch-wan Jun 18, 2025
dcf58ee
Merge commit 'e56685ac1bb881e58043fe5f2c4ae055905332ba' into cheng/fi…
ch-wan Jun 18, 2025
84c31ff
update tests
ch-wan Jun 18, 2025
57857af
adapt to dp attn with mtp
ch-wan Jun 18, 2025
1e7012e
format
ch-wan Jun 18, 2025
5a1d9bd
fix mlp sync
ch-wan Jun 18, 2025
937451b
Merge commit '09ae5b20f3123487f36097d284a1f535cd267e7b' into cheng/fi…
ch-wan Jun 18, 2025
e4aab35
format
ch-wan Jun 18, 2025
a14115f
add MTP tests
ch-wan Jun 20, 2025
3c6de24
update file name and intro
ch-wan Jun 20, 2025
14 changes: 8 additions & 6 deletions python/sglang/bench_one_batch.py
@@ -71,6 +71,8 @@
configure_logger,
get_bool_env_var,
kill_process_tree,
require_mlp_sync,
require_mlp_tp_gather,
set_gpu_proc_affinity,
suppress_other_loggers,
)
@@ -243,7 +245,7 @@ def extend(reqs, model_runner):
enable_custom_logit_processor=False,
)
batch.prepare_for_extend()
_maybe_prepare_dp_attn_batch(batch, model_runner)
_maybe_prepare_mlp_sync_batch(batch, model_runner)
model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
logits_output, _ = model_runner.forward(forward_batch)
@@ -255,26 +257,26 @@ def extend(reqs, model_runner):
def decode(input_token_ids, batch, model_runner):
batch.output_ids = input_token_ids
batch.prepare_for_decode()
_maybe_prepare_dp_attn_batch(batch, model_runner)
_maybe_prepare_mlp_sync_batch(batch, model_runner)
model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
logits_output, _ = model_runner.forward(forward_batch)
next_token_ids = model_runner.sample(logits_output, forward_batch)
return next_token_ids, logits_output.next_token_logits


def _maybe_prepare_dp_attn_batch(batch: ScheduleBatch, model_runner):
if model_runner.server_args.enable_dp_attention:
Scheduler.prepare_dp_attn_batch_raw(
def _maybe_prepare_mlp_sync_batch(batch: ScheduleBatch, model_runner):
if require_mlp_sync(model_runner.server_args):
Scheduler.prepare_mlp_sync_batch_raw(
batch,
dp_size=model_runner.server_args.dp_size,
attn_tp_size=1,
moe_dense_tp_size=model_runner.server_args.moe_dense_tp_size,
tp_cpu_group=model_runner.tp_group.cpu_group,
get_idle_batch=None,
disable_cuda_graph=model_runner.server_args.disable_cuda_graph,
spec_algorithm=SpeculativeAlgorithm.NONE,
speculative_num_draft_tokens=None,
require_mlp_tp_gather=require_mlp_tp_gather(model_runner.server_args),
)


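The two helpers imported above, require_mlp_sync and require_mlp_tp_gather, are defined in sglang.srt.utils and are not shown in this diff. A rough sketch of what they check, inferred only from the call sites in this PR (not the actual implementation), is:

    # Sketch inferred from call sites in this PR; the real sglang.srt.utils helpers may differ.
    def require_mlp_sync(server_args) -> bool:
        # The call sites this PR replaces were gated on DP attention or SP layernorm.
        return server_args.enable_dp_attention or server_args.enable_sp_layernorm

    def require_mlp_tp_gather(server_args) -> bool:
        # The old scheduler branch kept token counts local only when dense MLPs run
        # with TP size 1 and the DP LM head is enabled (read from
        # global_server_args_dict in the scheduler diff below); otherwise the MLP
        # input must be gathered across the TP group.
        return not (
            server_args.moe_dense_tp_size == 1 and server_args.enable_dp_lm_head
        )
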
33 changes: 14 additions & 19 deletions python/sglang/srt/disaggregation/decode.py
@@ -54,6 +54,7 @@
)
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.utils import require_mlp_sync

logger = logging.getLogger(__name__)

@@ -645,10 +646,7 @@ def event_loop_normal_disagg_decode(self: Scheduler):
batch = self.get_next_disagg_decode_batch_to_run()
self.cur_batch = batch

prepare_dp_attn_flag = (
self.server_args.enable_dp_attention
or self.server_args.enable_sp_layernorm
)
prepare_mlp_sync_flag = require_mlp_sync(self.server_args)

if batch:
# Generate fake extend output.
@@ -657,14 +655,14 @@ def event_loop_normal_disagg_decode(self: Scheduler):
self.stream_output(
batch.reqs, any(req.return_logprob for req in batch.reqs)
)
if prepare_dp_attn_flag:
if prepare_mlp_sync_flag:
self._prepare_idle_batch_and_run(None)
else:
if prepare_dp_attn_flag:
self.prepare_dp_attn_batch(batch)
if prepare_mlp_sync_flag:
self.prepare_mlp_sync_batch(batch)
result = self.run_batch(batch)
self.process_batch_result(batch, result)
elif prepare_dp_attn_flag:
elif prepare_mlp_sync_flag:
batch, _ = self._prepare_idle_batch_and_run(None)

if batch is None and (
@@ -695,10 +693,7 @@ def event_loop_overlap_disagg_decode(self: Scheduler):
self.cur_batch = batch
last_batch_in_queue = False

prepare_dp_attn_flag = (
self.server_args.enable_dp_attention
or self.server_args.enable_sp_layernorm
)
prepare_mlp_sync_flag = require_mlp_sync(self.server_args)

if batch:
# Generate fake extend output.
@@ -707,16 +702,16 @@ def event_loop_overlap_disagg_decode(self: Scheduler):
self.stream_output(
batch.reqs, any(req.return_logprob for req in batch.reqs)
)
if prepare_dp_attn_flag:
if prepare_mlp_sync_flag:
batch_, result = self._prepare_idle_batch_and_run(
None, delay_process=True
)
if batch_:
result_queue.append((batch_.copy(), result))
last_batch_in_queue = True
else:
if prepare_dp_attn_flag:
self.prepare_dp_attn_batch(batch)
if prepare_mlp_sync_flag:
self.prepare_mlp_sync_batch(batch)
result = self.run_batch(batch)
result_queue.append((batch.copy(), result))

@@ -731,7 +726,7 @@ def event_loop_overlap_disagg_decode(self: Scheduler):
self.set_next_batch_sampling_info_done(tmp_batch)
last_batch_in_queue = True

elif prepare_dp_attn_flag:
elif prepare_mlp_sync_flag:
batch, result = self._prepare_idle_batch_and_run(
None, delay_process=True
)
@@ -761,13 +756,13 @@ def event_loop_overlap_disagg_decode(self: Scheduler):
self.last_batch = batch
self.last_batch_in_queue = last_batch_in_queue

def _prepare_idle_batch_and_run(self, batch, delay_process=False):
batch, _ = self.prepare_dp_attn_batch(batch)
def _prepare_idle_batch_and_run(self: Scheduler, batch, delay_process=False):
batch, _ = self.prepare_mlp_sync_batch(batch)
result = None
if batch:
result = self.run_batch(batch)
if not delay_process:
self.process_batch_result(batch, result)
self.prepare_mlp_sync_batch(batch, result)
return batch, result

def get_next_disagg_decode_batch_to_run(
17 changes: 5 additions & 12 deletions python/sglang/srt/disaggregation/prefill.py
@@ -45,6 +45,7 @@
)
from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.utils import require_mlp_sync

if TYPE_CHECKING:
from torch.distributed import ProcessGroup
@@ -274,12 +275,8 @@ def event_loop_normal_disagg_prefill(self: Scheduler) -> None:
self.process_prefill_chunk()
batch = self.get_new_batch_prefill()

# Handle DP attention
if (
self.server_args.enable_dp_attention
or self.server_args.enable_sp_layernorm
):
batch, _ = self.prepare_dp_attn_batch(batch)
if require_mlp_sync(self.server_args):
batch, _ = self.prepare_mlp_sync_batch(batch)
self.cur_batch = batch

if batch:
@@ -312,12 +309,8 @@ def event_loop_overlap_disagg_prefill(self: Scheduler) -> None:
self.process_prefill_chunk()
batch = self.get_new_batch_prefill()

# Handle DP attention
if (
self.server_args.enable_dp_attention
or self.server_args.enable_sp_layernorm
):
batch, _ = self.prepare_dp_attn_batch(batch)
if require_mlp_sync(self.server_args):
batch, _ = self.prepare_mlp_sync_batch(batch)
self.cur_batch = batch
if batch:
result = self.run_batch(batch)
10 changes: 5 additions & 5 deletions python/sglang/srt/layers/communicator.py
@@ -28,9 +28,9 @@
attn_tp_reduce_scatter,
dp_gather_partial,
dp_scatter,
get_attention_dp_size,
get_attention_tp_rank,
get_attention_tp_size,
Contributor review comment (medium): Consider removing get_local_attention_dp_size since it's no longer used after the change.

get_local_attention_dp_size,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -229,7 +229,7 @@ class CommunicateContext:
process_group_sizes: Dict[ScatterMode, int]
attn_tp_rank: int
attn_tp_size: int
local_attn_dp_size: int
attn_dp_size: int
tp_size: int

def is_same_group_size(self, a: ScatterMode, b: ScatterMode):
@@ -239,7 +239,7 @@ def is_same_group_size(self, a: ScatterMode, b: ScatterMode):
def init_new(cls):
attn_tp_rank = get_attention_tp_rank()
attn_tp_size = get_attention_tp_size()
local_attn_dp_size = get_local_attention_dp_size()
attn_dp_size = get_attention_dp_size()
Contributor review comment (medium): Use attn_dp_size instead of local_attn_dp_size to initialize the CommunicateContext. Suggested change:

        attn_tp_rank = get_attention_tp_rank()
        attn_tp_size = get_attention_tp_size()
        attn_dp_size = get_attention_dp_size()

tp_size = get_tensor_model_parallel_world_size()
process_group_sizes = {
ScatterMode.SCATTERED: 1,
@@ -251,7 +251,7 @@ def init_new(cls):
process_group_sizes=process_group_sizes,
attn_tp_rank=attn_tp_rank,
attn_tp_size=attn_tp_size,
local_attn_dp_size=local_attn_dp_size,
attn_dp_size=attn_dp_size,
Contributor review comment (medium): Use attn_dp_size instead of local_attn_dp_size to initialize the CommunicateContext. Suggested change:

            process_group_sizes=process_group_sizes,
            attn_tp_rank=attn_tp_rank,
            attn_tp_size=attn_tp_size,
            attn_dp_size=attn_dp_size,

tp_size=tp_size,
)

@@ -385,7 +385,7 @@ def _gather_hidden_states_and_residual(
attn_tp_all_gather(
list(residual.tensor_split(context.attn_tp_size)), local_residual
)
if context.local_attn_dp_size != 1:
if context.attn_dp_size != 1:
if context.attn_tp_rank == 0:
hidden_states += residual
hidden_states, local_hidden_states = (
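For context on the attn_dp_size != 1 branch above: hidden states are first all-gathered across the attention-TP group, and the extra, global DP gather is only needed when more than one attention-DP rank exists. The following standalone illustration uses plain torch.distributed rather than sglang's helpers; the group handles are hypothetical and, unlike dp_gather_partial, it assumes every rank holds the same number of tokens.

    import torch
    import torch.distributed as dist

    def gather_hidden_states(local_hidden, attn_tp_group, dp_group, attn_dp_size):
        # 1) Reassemble the tensor that was split across the attention-TP group.
        tp_world = dist.get_world_size(group=attn_tp_group)
        parts = [torch.empty_like(local_hidden) for _ in range(tp_world)]
        dist.all_gather(parts, local_hidden, group=attn_tp_group)
        hidden = torch.cat(parts, dim=0)
        # 2) Only when several attention-DP ranks exist is a second, global DP
        #    gather performed (the `attn_dp_size != 1` branch in the diff).
        if attn_dp_size != 1:
            dp_parts = [torch.empty_like(hidden) for _ in range(attn_dp_size)]
            dist.all_gather(dp_parts, hidden, group=dp_group)
            hidden = torch.cat(dp_parts, dim=0)
        return hidden
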
3 changes: 2 additions & 1 deletion python/sglang/srt/layers/dp_attention.py
@@ -165,7 +165,8 @@ def disable_dp_size():


def get_dp_local_info(forward_batch: ForwardBatch):
dp_rank = get_local_attention_dp_rank()
# `get_dp_local_info` is only called in global DP gather and scatter. We use global DP rank here.
dp_rank = get_attention_dp_rank()

if forward_batch.dp_local_start_pos is None:
cumtokens = torch.cumsum(forward_batch.global_num_tokens_gpu, dim=0)
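The hunk above shows only the start of get_dp_local_info; the per-rank slice is then typically derived from the cumulative token counts. A small illustrative helper (hypothetical name and signature, not the exact sglang code):

    import torch

    def dp_local_slice(global_num_tokens: torch.Tensor, dp_rank: int):
        # Cumulative sum of per-DP-rank token counts gives the slice boundaries.
        cumtokens = torch.cumsum(global_num_tokens, dim=0)
        # Rank 0 starts at 0; rank r starts where rank r-1 ends.
        start = torch.zeros_like(cumtokens[0]) if dp_rank == 0 else cumtokens[dp_rank - 1]
        num_tokens = global_num_tokens[dp_rank]
        return start, num_tokens

For example, dp_local_slice(torch.tensor([3, 5, 2]), dp_rank=1) yields start 3 and num_tokens 5.
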
4 changes: 2 additions & 2 deletions python/sglang/srt/layers/logits_processor.py
@@ -30,9 +30,9 @@
attn_tp_all_gather,
dp_gather_replicate,
dp_scatter,
get_attention_dp_rank,
get_attention_dp_size,
get_attention_tp_size,
get_local_attention_dp_rank,
get_local_attention_dp_size,
Contributor review comment on lines +33 to 36 (medium): Consider removing get_local_attention_dp_rank and get_local_attention_dp_size since they are no longer used after the change.

)
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
@@ -171,7 +171,7 @@ def compute_dp_attention_metadata(self, hidden_states: torch.Tensor):
return

cumtokens = torch.cumsum(self.global_num_tokens_for_logprob_gpu, dim=0)
dp_rank = get_local_attention_dp_rank()
dp_rank = get_attention_dp_rank()
if dp_rank == 0:
dp_local_start_pos = torch.zeros_like(
self.global_num_tokens_for_logprob_gpu[0]
19 changes: 10 additions & 9 deletions python/sglang/srt/managers/scheduler.py
@@ -148,6 +148,8 @@
kill_itself_when_parent_died,
point_to_point_pyobj,
pyspy_dump_schedulers,
require_mlp_sync,
require_mlp_tp_gather,
set_gpu_proc_affinity,
set_random_seed,
suppress_other_loggers,
@@ -1434,9 +1436,8 @@ def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
else:
ret = None

# Handle DP attention
if self.server_args.enable_dp_attention or self.server_args.enable_sp_layernorm:
ret, _ = self.prepare_dp_attn_batch(ret)
if require_mlp_sync(self.server_args):
ret, _ = self.prepare_mlp_sync_batch(ret)

return ret

@@ -1746,12 +1747,11 @@ def process_batch_result(
self.return_health_check_ct -= 1
self.send_to_tokenizer.send_pyobj(HealthCheckOutput())

def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
return self.prepare_dp_attn_batch_raw(
def prepare_mlp_sync_batch(self, local_batch: ScheduleBatch):
return self.prepare_mlp_sync_batch_raw(
local_batch,
dp_size=self.server_args.dp_size,
attn_tp_size=self.attn_tp_size,
moe_dense_tp_size=self.server_args.moe_dense_tp_size,
tp_cpu_group=self.tp_cpu_group,
get_idle_batch=self.get_idle_batch,
disable_cuda_graph=self.server_args.disable_cuda_graph,
@@ -1760,14 +1760,14 @@ def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
enable_two_batch_overlap=self.server_args.enable_two_batch_overlap,
enable_deepep_moe=self.server_args.enable_deepep_moe,
deepep_mode=DeepEPMode[self.server_args.deepep_mode],
require_mlp_tp_gather=require_mlp_tp_gather(self.server_args),
)

@staticmethod
def prepare_dp_attn_batch_raw(
def prepare_mlp_sync_batch_raw(
local_batch: ScheduleBatch,
dp_size,
attn_tp_size: int,
moe_dense_tp_size: Optional[int],
tp_cpu_group,
get_idle_batch,
disable_cuda_graph: bool,
@@ -1776,6 +1776,7 @@ def prepare_dp_attn_batch_raw(
enable_two_batch_overlap: bool,
enable_deepep_moe: bool,
deepep_mode: DeepEPMode,
require_mlp_tp_gather: bool,
):
# Check if other DP workers have running batches
if local_batch is None:
@@ -1850,7 +1851,7 @@ def prepare_dp_attn_batch_raw(

if local_batch is not None:
# TODO: handle the case when moe_dense_tp_size != 1
if moe_dense_tp_size == 1 and global_server_args_dict["enable_dp_lm_head"]:
if not require_mlp_tp_gather:
local_batch.global_num_tokens = [num_tokens]
local_batch.global_num_tokens_for_logprob = [num_tokens_for_logprob]
else: