Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 41 additions & 9 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
create_grammar_backend,
)
from sglang.srt.disaggregation.decode import (
CLIP_MAX_NEW_TOKEN,
DecodePreallocQueue,
DecodeTransferQueue,
SchedulerDisaggregationDecodeMixin,
Expand Down Expand Up @@ -1685,20 +1686,29 @@ def _get_swa_token_info(self):
)

def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
if self.is_hybrid:
rem_total_tokens = min(
self.token_to_kv_pool_allocator.full_available_size()
+ self.tree_cache.full_evictable_size(),
self.token_to_kv_pool_allocator.swa_available_size()
+ self.tree_cache.swa_evictable_size(),
)
elif self.is_hybrid_gdn:
rem_total_tokens = (
self.token_to_kv_pool_allocator.available_size()
+ self.tree_cache.full_evictable_size()
)
else:
rem_total_tokens = (
self.token_to_kv_pool_allocator.available_size()
+ self.tree_cache.evictable_size()
)
# Merge the prefill batch into the running batch
chunked_req_to_exclude = set()
if self.chunked_req:
# Move the chunked request out of the batch so that we can merge
# only finished requests to running_batch.
chunked_req_to_exclude.add(self.chunked_req)
self.tree_cache.cache_unfinished_req(self.chunked_req, chunked=True)
# chunked request keeps its rid but will get a new req_pool_idx
if self.tp_worker.model_runner.mambaish_config is not None:
self.req_to_token_pool.free(
self.chunked_req.req_pool_idx, free_mamba_cache=False
)
else:
self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
if self.last_batch and self.last_batch.forward_mode.is_extend():
if self.last_batch.chunked_req is not None:
# In the context of pipeline parallelism, after the last chunk, the current microbatch still tracks an outdated chunked_req.
Expand All @@ -1722,7 +1732,29 @@ def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
# Merge running_batch with prefill batch
self.running_batch.merge_batch(self.last_batch)

new_batch = self.get_new_batch_prefill()
rem_total_tokens -= sum(
[
min(
r.sampling_params.max_new_tokens - len(r.output_ids),
CLIP_MAX_NEW_TOKEN,
)
* self.new_token_ratio
for r in self.running_batch.reqs
]
)
if self.chunked_req and rem_total_tokens > 0:
self.tree_cache.cache_unfinished_req(self.chunked_req, chunked=True)
# chunked request keeps its rid but will get a new req_pool_idx
if self.tp_worker.model_runner.mambaish_config is not None:
self.req_to_token_pool.free(
self.chunked_req.req_pool_idx, free_mamba_cache=False
)
else:
self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
if rem_total_tokens > 0:
new_batch = self.get_new_batch_prefill()
else:
new_batch = None

need_mlp_sync = self.require_mlp_sync
if need_mlp_sync and not self.spec_algorithm.is_none():
Expand Down
Loading