diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index cfbde332fcf4..3ea790ddd6b1 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -41,6 +41,7 @@
     create_grammar_backend,
 )
 from sglang.srt.disaggregation.decode import (
+    CLIP_MAX_NEW_TOKEN,
     DecodePreallocQueue,
     DecodeTransferQueue,
     SchedulerDisaggregationDecodeMixin,
@@ -1685,20 +1686,29 @@ def _get_swa_token_info(self):
         )
 
     def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
+        if self.is_hybrid:
+            rem_total_tokens = min(
+                self.token_to_kv_pool_allocator.full_available_size()
+                + self.tree_cache.full_evictable_size(),
+                self.token_to_kv_pool_allocator.swa_available_size()
+                + self.tree_cache.swa_evictable_size(),
+            )
+        elif self.is_hybrid_gdn:
+            rem_total_tokens = (
+                self.token_to_kv_pool_allocator.available_size()
+                + self.tree_cache.full_evictable_size()
+            )
+        else:
+            rem_total_tokens = (
+                self.token_to_kv_pool_allocator.available_size()
+                + self.tree_cache.evictable_size()
+            )
         # Merge the prefill batch into the running batch
         chunked_req_to_exclude = set()
         if self.chunked_req:
             # Move the chunked request out of the batch so that we can merge
             # only finished requests to running_batch.
             chunked_req_to_exclude.add(self.chunked_req)
-            self.tree_cache.cache_unfinished_req(self.chunked_req, chunked=True)
-            # chunked request keeps its rid but will get a new req_pool_idx
-            if self.tp_worker.model_runner.mambaish_config is not None:
-                self.req_to_token_pool.free(
-                    self.chunked_req.req_pool_idx, free_mamba_cache=False
-                )
-            else:
-                self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
         if self.last_batch and self.last_batch.forward_mode.is_extend():
             if self.last_batch.chunked_req is not None:
                 # In the context pipeline parallelism, after the last chunk, the current microbatch still track outdated chunked_req.
@@ -1722,7 +1732,29 @@ def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
             # Merge running_batch with prefill batch
             self.running_batch.merge_batch(self.last_batch)
 
-        new_batch = self.get_new_batch_prefill()
+        rem_total_tokens -= sum(
+            [
+                min(
+                    r.sampling_params.max_new_tokens - len(r.output_ids),
+                    CLIP_MAX_NEW_TOKEN,
+                )
+                * self.new_token_ratio
+                for r in self.running_batch.reqs
+            ]
+        )
+        if self.chunked_req and rem_total_tokens > 0:
+            self.tree_cache.cache_unfinished_req(self.chunked_req, chunked=True)
+            # chunked request keeps its rid but will get a new req_pool_idx
+            if self.tp_worker.model_runner.mambaish_config is not None:
+                self.req_to_token_pool.free(
+                    self.chunked_req.req_pool_idx, free_mamba_cache=False
+                )
+            else:
+                self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
+        if rem_total_tokens > 0:
+            new_batch = self.get_new_batch_prefill()
+        else:
+            new_batch = None
 
         need_mlp_sync = self.require_mlp_sync
         if need_mlp_sync and not self.spec_algorithm.is_none():
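
For readers skimming the patch, here is a minimal, self-contained sketch of the token-budget heuristic it introduces: estimate the KV-cache headroom (free slots plus evictable cached slots), reserve room for the decode tokens the running requests are still expected to produce (clipped per request and scaled by the scheduler's new-token ratio), and only cache the chunked request and schedule new prefill work while the estimate stays positive. The helper names, the dataclass, and the CLIP_MAX_NEW_TOKEN value below are illustrative assumptions, not the sglang scheduler API; only the identifiers that appear in the diff (CLIP_MAX_NEW_TOKEN, new_token_ratio, max_new_tokens, output_ids) come from the patch itself.

    # Hypothetical standalone sketch; not sglang code. The real CLIP_MAX_NEW_TOKEN
    # lives in sglang.srt.disaggregation.decode and its value is assumed here.
    from dataclasses import dataclass
    from typing import List

    CLIP_MAX_NEW_TOKEN = 512  # assumed value for illustration

    @dataclass
    class RunningReq:
        max_new_tokens: int   # r.sampling_params.max_new_tokens in the diff
        num_output_ids: int   # len(r.output_ids) in the diff

    def estimate_rem_total_tokens(
        available_size: int,      # token_to_kv_pool_allocator.available_size()
        evictable_size: int,      # tree_cache.evictable_size()
        running_reqs: List[RunningReq],
        new_token_ratio: float,
    ) -> float:
        # Headroom = free KV slots plus slots that could be evicted from the cache.
        rem = float(available_size + evictable_size)
        # Reserve the decode tokens the running requests still need, clipped per
        # request and scaled by the scheduler's new-token ratio.
        for r in running_reqs:
            remaining = min(r.max_new_tokens - r.num_output_ids, CLIP_MAX_NEW_TOKEN)
            rem -= remaining * new_token_ratio
        return rem

    # Mirrors the gating in the patch: prefill (and caching of the chunked
    # request) only happens while the estimate is positive.
    if estimate_rem_total_tokens(20000, 4096, [RunningReq(1024, 100)], 0.7) > 0:
        print("would call get_new_batch_prefill()")
    else:
        print("new_batch = None; keep decoding the running batch")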