Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
d172c82
add pp mixin
merrymercy Oct 19, 2025
437f100
update point_to_point_pyobj
merrymercy Oct 19, 2025
27b78d5
[PP] support async PP
XucSh Oct 20, 2025
269e64b
update
XucSh Oct 20, 2025
14125ae
Merge branch 'main' into Xuchun/pp-dev
ShangmingCai Oct 21, 2025
cad6210
upd
ShangmingCai Oct 21, 2025
53335cc
fix
ShangmingCai Oct 21, 2025
c493200
fix
XucSh Oct 21, 2025
44656b8
fix
XucSh Oct 21, 2025
21df2f2
async run and process
XucSh Oct 21, 2025
8fec316
fix
XucSh Oct 22, 2025
382ce86
fix
XucSh Oct 22, 2025
770b28f
update
XucSh Oct 22, 2025
0dbc810
Merge branch 'main' into Xuchun/pp-dev
XucSh Oct 23, 2025
49a81dc
tiny improvement by delaying commit sync
bluecoffee8 Oct 23, 2025
ac347c5
send req asap
bluecoffee8 Oct 23, 2025
2efd976
cleanup PR
bluecoffee8 Oct 23, 2025
7dbc256
add recv tensor to cuda stream
bluecoffee8 Oct 24, 2025
8b03c94
move sync step forward
bluecoffee8 Oct 24, 2025
b2e95f0
move out of copy stream
bluecoffee8 Oct 24, 2025
49ba293
commit send proxy work before launch batch
bluecoffee8 Oct 24, 2025
5d07c36
Merge branch 'main' into Xuchun/pp-dev
XucSh Oct 27, 2025
4ba6fb8
Merge branch 'main' into Xuchun/pp-dev
XucSh Oct 28, 2025
82d91d9
Merge remote-tracking branch 'origin/main' into pp-pr
XucSh Oct 29, 2025
aa62f6c
Merge remote-tracking branch 'origin/main' into pp-pr
XucSh Nov 5, 2025
60ce684
Update scheduler_pp_mixin.py
zhangxiaolei123456 Nov 13, 2025
60be396
Merge remote-tracking branch 'origin/main' into pp-pr
XucSh Nov 13, 2025
17618e2
Merge pull request #10 from zhangxiaolei123456/pd_pp_dev_zhang
ShangmingCai Nov 14, 2025
d3b26d0
fix
ShangmingCai Nov 14, 2025
30427d6
fix
ShangmingCai Nov 14, 2025
5bef428
Merge branch 'main' into Xuchun/pp-dev
XucSh Nov 18, 2025
6df9b79
Merge branch 'main' into Xuchun/pp-dev
XucSh Nov 21, 2025
026b942
fix
XucSh Nov 24, 2025
673c37f
Merge branch 'main' into Xuchun/pp-dev
XucSh Nov 24, 2025
59c8ef3
fix
XucSh Nov 24, 2025
4015608
add dynamic chunk support
ShangmingCai Nov 25, 2025
368382a
clean
ShangmingCai Nov 25, 2025
d13b07c
remove redundant code
ShangmingCai Nov 26, 2025
7f65270
Merge branch 'main' into Xuchun/pp-dev
XucSh Dec 2, 2025
34faf31
add smooth coeff
ShangmingCai Dec 4, 2025
44452b0
Merge branch 'main' into Xuchun/pp-dev
XucSh Dec 8, 2025
e1fd4d6
Merge branch 'main' into Xuchun/pp-dev
ShangmingCai Dec 11, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 0 additions & 36 deletions python/sglang/srt/disaggregation/prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@
SWAKVPool,
)
from sglang.srt.tracing.trace import trace_event_batch, trace_slice, trace_slice_end
from sglang.srt.utils import broadcast_pyobj, point_to_point_pyobj

if TYPE_CHECKING:
from torch.distributed import ProcessGroup
Expand Down Expand Up @@ -252,8 +251,6 @@ def pop_bootstrapped(
# if req not in reqs_info_to_check, skip
if req.rid not in rids_to_check:
continue
# Either waiting for input or failed
assert poll == KVPoll.WaitingForInput or poll == KVPoll.Failed

if poll == KVPoll.Bootstrapping:
continue
Expand Down Expand Up @@ -710,36 +707,3 @@ def send_kv_chunk(
)
return
req.disagg_kv_sender.send(page_indices, state_indices)

def send_pyobj_to_next_stage(self, data):
    """Hand a Python object over to the next pipeline-parallel stage.

    Only the worker whose ``attn_tp_rank`` is 0 performs the transfer; every
    other worker returns immediately without touching ``data``.

    The peer is derived from ``(pp_rank + 1) % pp_size``, so the last stage
    wraps around to stage 0. Both endpoints are shifted by the same
    ``attn_dp_rank * attn_tp_size`` offset.

    Args:
        data: The object forwarded to ``point_to_point_pyobj``.
    """
    if self.attn_tp_rank != 0:
        return

    dp_offset = self.attn_dp_rank * self.attn_tp_size
    # This worker's own rank is passed twice: once as the caller's rank and
    # once as the sending endpoint (NOTE(review): argument roles inferred
    # from the call shape -- confirm against point_to_point_pyobj).
    own_rank = self.pp_rank * self.tp_size + dp_offset
    next_stage = (self.pp_rank + 1) % self.pp_size
    next_rank = next_stage * self.tp_size + dp_offset

    point_to_point_pyobj(
        data,
        own_rank,
        self.world_group.device_group,
        own_rank,
        next_rank,
    )

def recv_pyobj_from_prev_stage(self):
    """Fetch the Python object sent by the previous pipeline-parallel stage.

    The worker with ``attn_tp_rank == 0`` receives the object through
    ``point_to_point_pyobj``; the peer stage is ``(pp_rank - 1) % pp_size``,
    so stage 0 wraps around to the last stage. All other workers start with
    ``None``.

    When ``attn_tp_size`` is not 1, the value is then passed through
    ``broadcast_pyobj`` with the first rank of ``attn_tp_group`` as source,
    and the broadcast result is returned.

    Returns:
        The received (and, if applicable, broadcast) object. ``None`` on a
        worker that neither receives nor takes part in a broadcast.
    """
    received = None

    if self.attn_tp_rank == 0:
        dp_offset = self.attn_dp_rank * self.attn_tp_size
        # Own rank serves both as the caller's rank and as the receiving
        # endpoint (NOTE(review): argument roles inferred from the call
        # shape -- confirm against point_to_point_pyobj).
        own_rank = self.pp_rank * self.tp_size + dp_offset
        prev_stage = (self.pp_rank - 1) % self.pp_size
        prev_rank = prev_stage * self.tp_size + dp_offset
        received = point_to_point_pyobj(
            [],
            own_rank,
            self.world_group.device_group,
            prev_rank,
            own_rank,
        )

    # A single-member attention-TP group has nobody to share the value with.
    if self.attn_tp_size == 1:
        return received

    return broadcast_pyobj(
        received,
        self.attn_tp_group.rank,
        self.attn_tp_cpu_group,
        src=self.attn_tp_group.ranks[0],
    )
1 change: 1 addition & 0 deletions python/sglang/srt/environ.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ class Envs:
SGLANG_DISABLE_CONSECUTIVE_PREFILL_OVERLAP = EnvBool(False)
SGLANG_SCHEDULER_MAX_RECV_PER_POLL = EnvInt(-1)
SGLANG_EXPERIMENTAL_CPP_RADIX_TREE = EnvBool(False)
SGLANG_DYNAMIC_CHUNKING_SMOOTH_FACTOR = EnvFloat(0.75)

# Test: pd-disaggregation
SGLANG_TEST_PD_DISAGG_BACKEND = EnvStr("mooncake")
Expand Down
49 changes: 42 additions & 7 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@
from sglang.srt.mem_cache.cache_init_params import CacheInitParams
from sglang.srt.mem_cache.common import release_kv_cache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
from sglang.srt.multiplex.multiplexing_mixin import SchedulerMultiplexMixin
from sglang.srt.parser.reasoning_parser import ReasoningParser
from sglang.srt.server_args import PortArgs, ServerArgs, get_global_server_args
Expand Down Expand Up @@ -472,6 +472,21 @@ def __init__(
self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
)

self.enable_dynamic_chunking = (
server_args.enable_dynamic_chunking and self.pp_size > 1
)

# Init the dynamic chunking predictor for PP
if self.enable_dynamic_chunking:
try:
self.profile_and_init_predictor()
except Exception as e:
logger.warning(
f"[PP Dynamic Chunk] Failed to profile prefill latency: {e}. "
"Dynamic chunking will be disabled."
)
self.enable_dynamic_chunking = False

# Init the grammar backend for constrained generation
self.grammar_queue: List[Req] = []
if not server_args.skip_tokenizer_init:
Expand Down Expand Up @@ -934,8 +949,7 @@ def init_disaggregation(self):

def init_overlap(self):
self.future_map = None

if not self.enable_overlap:
if not self.enable_overlap and self.pp_size == 1:
return

self.forward_stream: CudaStream = torch.get_device_module(self.device).Stream()
Expand All @@ -947,6 +961,9 @@ def init_overlap(self):
self.device
).stream(self.copy_stream)

if not self.enable_overlap:
return

self.future_map = FutureMap(
self.max_running_requests,
self.chunked_prefill_size,
Expand Down Expand Up @@ -1108,7 +1125,7 @@ def recv_requests(
recv_reqs = point_to_point_pyobj(
[],
self.pp_rank * self.tp_size + dp_offset,
self.world_group.device_group,
self.world_group.cpu_group,
(self.pp_rank - 1) * self.tp_size + dp_offset,
self.pp_rank * self.tp_size + dp_offset,
)
Expand Down Expand Up @@ -1766,6 +1783,16 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
# in the waiting queue.
return None

# Determine chunked_prefill_size for this batch
chunked_prefill_size = self.chunked_prefill_size
if self.chunked_req is not None:
self.chunked_req.init_next_round_input()
if self.enable_dynamic_chunking:
history_len = len(self.chunked_req.prefix_indices)
dynamic_size = self.predict_next_chunk_size(history_len)
if dynamic_size is not None:
chunked_prefill_size = dynamic_size

# Prefill policy
adder = PrefillAdder(
self.page_size,
Expand All @@ -1774,7 +1801,7 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
self.running_batch,
self.new_token_ratio,
self.max_prefill_tokens,
self.chunked_prefill_size,
chunked_prefill_size,
running_bs if self.is_mixed_chunk else 0,
self.priority_scheduling_preemption_threshold,
)
Expand Down Expand Up @@ -1966,7 +1993,9 @@ def update_cache_from_scheduler(
pass

def run_batch(
self, batch: ScheduleBatch
self,
batch: ScheduleBatch,
pp_proxy_tensors: Optional[PPProxyTensors] = None,
) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
"""Run a batch."""
self.forward_ct += 1
Expand Down Expand Up @@ -2014,6 +2043,7 @@ def run_batch(
self.future_map.resolve_future(model_worker_batch)
batch_result = self.model_worker.forward_batch_generation(
model_worker_batch
# here pp is not compatible with overlap
)
# FIXME(lsyin): maybe move this to forward_batch_generation
batch_result.copy_done = torch.get_device_module(
Expand Down Expand Up @@ -2047,8 +2077,13 @@ def run_batch(
batch_result = self.tp_worker.forward_batch_split_prefill(batch)
future_indices_or_next_token_ids = batch_result.next_token_ids
else:
kwargs = (
{"pp_proxy_tensors": pp_proxy_tensors}
if self.spec_algorithm.is_none()
else {}
)
batch_result = self.model_worker.forward_batch_generation(
batch_or_worker_batch
batch_or_worker_batch, **kwargs
)
future_indices_or_next_token_ids = batch_result.next_token_ids
self.update_cache_from_scheduler(batch, batch_result)
Expand Down
Loading
Loading