Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 13 additions & 10 deletions python/sglang/srt/managers/scheduler_pp_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import numpy as np
import torch
import torch.distributed
from tqdm import tqdm

from sglang.srt.disaggregation.base.conn import KVPoll
from sglang.srt.disaggregation.utils import DisaggregationMode, poll_and_all_reduce
Expand Down Expand Up @@ -532,13 +533,11 @@ def profile_and_init_predictor(self: Scheduler):
latencies: List[float] = []

if self.pp_group.is_first_rank:
logger.info("Profiling prefill latency for dynamic chunk sizing...")

# Create requests with linearly decreasing lengths, from 1.25 * chunked_prefill_size down toward 0 over 128 steps
input_ids_list = []
for i in range(32):
chunk_size = self.chunked_prefill_size - i * (
self.chunked_prefill_size // 32
for i in range(128):
chunk_size = int(
self.chunked_prefill_size * 1.25
- i * (self.chunked_prefill_size * 1.25 // 128)
)
if chunk_size <= 0:
break
Expand All @@ -551,9 +550,13 @@ def profile_and_init_predictor(self: Scheduler):
temperature=0,
max_new_tokens=1,
)

# Create and profile requests
for i, input_ids in enumerate(input_ids_list):
for i, input_ids in enumerate(
tqdm(
input_ids_list,
desc="Profiling prefill latency for dynamic chunking",
)
):
req = Req(
rid=str(i),
origin_input_text="",
Expand Down Expand Up @@ -1338,8 +1341,8 @@ def predict_next_chunk_size(
)
calculated_chunk_size = int(smoothed_chunk_size)

# Align to page_size (round down to nearest multiple)
alignment_size = max(page_size, 1)
# Align to page_size (minimum alignment size is 64)
alignment_size = max(page_size, 64)
dynamic_chunk_size = (calculated_chunk_size // alignment_size) * alignment_size

# Ensure aligned size is at least alignment_size
Expand Down
Loading