From ab4a83b25909aa98330b838a224e4fe5c943e483 Mon Sep 17 00:00:00 2001
From: Liangsheng Yin
Date: Thu, 5 Sep 2024 14:30:26 -0700
Subject: [PATCH] Optimize schedule (#1339)

---
 .../sglang/srt/managers/policy_scheduler.py | 110 +++++++++++++++++-
 python/sglang/srt/managers/tp_worker.py     |  21 +++-
 2 files changed, 123 insertions(+), 8 deletions(-)

diff --git a/python/sglang/srt/managers/policy_scheduler.py b/python/sglang/srt/managers/policy_scheduler.py
index 04169e80861..3a70bfe5482 100644
--- a/python/sglang/srt/managers/policy_scheduler.py
+++ b/python/sglang/srt/managers/policy_scheduler.py
@@ -108,18 +108,24 @@ class PrefillAdder:
     def __init__(
         self,
         tree_cache: BasePrefixCache,
+        running_batch: ScheduleBatch,
+        new_token_ratio: float,
         rem_total_tokens: int,
         rem_input_tokens: int,
         rem_chunk_tokens: Optional[int],
         mixed_with_decode_tokens: int = 0,
     ):
         self.tree_cache = tree_cache
+        self.running_batch = running_batch
+        self.new_token_ratio = new_token_ratio
         self.rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens
+        self.total_tokens = rem_total_tokens
         self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
         self.rem_chunk_tokens = rem_chunk_tokens
         if self.rem_chunk_tokens is not None:
             self.rem_chunk_tokens -= mixed_with_decode_tokens
 
+        self.req_states = None
         self.can_run_list = []
         self.new_inflight_req = None
         self.log_hit_tokens = 0
@@ -136,16 +142,14 @@ def no_remaining_tokens(self):
             )
         )
 
-    def remove_running_tokens(
-        self, running_batch: ScheduleBatch, new_token_ratio: float
-    ):
+    def remove_running_tokens(self, running_batch: ScheduleBatch):
         self.rem_total_tokens -= sum(
             [
                 min(
                     (r.sampling_params.max_new_tokens - len(r.output_ids)),
                     CLIP_MAX_NEW_TOKENS,
                 )
-                * new_token_ratio
+                * self.new_token_ratio
                 for r in running_batch.reqs
             ]
         )
@@ -161,7 +165,29 @@ def _prefill_one_req(
         self.log_hit_tokens += prefix_len
         self.log_input_tokens += extend_input_len
 
+    def add_inflight_req_ignore_eos(self, req: Req):
+        truncated = req.extend_input_len > self.rem_chunk_tokens
+        req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
+        req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
+        self.can_run_list.append(req)
+
+        self._prefill_one_req(
+            0,
+            req.extend_input_len,
+            (
+                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS)
+                if not truncated
+                else 0
+            ),
+        )
+
+        # Return if chunked prefill not finished
+        return req if truncated else None
+
     def add_inflight_req(self, req: Req):
+        if req.sampling_params.ignore_eos:
+            return self.add_inflight_req_ignore_eos(req)
+
         truncated = req.extend_input_len > self.rem_chunk_tokens
         req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
         req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
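A note on the `add_inflight_req_ignore_eos` hunk above, for readers unfamiliar with chunked prefill: a request whose remaining prompt (`extend_input_len`) does not fit into the per-step chunk budget (`rem_chunk_tokens`) is prefilled one chunk at a time, and the decode-token reservation (`max_new_tokens`, clipped to `CLIP_MAX_NEW_TOKENS`) is charged only when the final chunk is scheduled. While chunks remain, the request itself is returned so the caller keeps it as its inflight request. The following is a minimal, self-contained sketch of that control flow; `ToyReq`, `add_chunked`, and the constant value are illustrative stand-ins, not sglang's real `Req`/`PrefillAdder` API.

from dataclasses import dataclass

CLIP_MAX_NEW_TOKENS = 256  # illustrative clip value, not sglang's setting


@dataclass
class ToyReq:
    extend_input_len: int  # prompt tokens still to be prefilled
    max_new_tokens: int    # decode budget requested by the user


def add_chunked(req: ToyReq, rem_chunk_tokens: int):
    """Schedule up to one chunk of req's remaining prefill.

    Returns the request if more chunks remain (it stays "inflight"),
    or None once this step finishes its prefill.
    """
    truncated = req.extend_input_len > rem_chunk_tokens
    scheduled_now = min(req.extend_input_len, rem_chunk_tokens)
    # Decode tokens are only reserved once the last chunk is scheduled.
    decode_budget = 0 if truncated else min(req.max_new_tokens, CLIP_MAX_NEW_TOKENS)
    print(f"prefill {scheduled_now} tokens, reserve {decode_budget} decode tokens")
    req.extend_input_len -= scheduled_now
    return req if truncated else None


req = ToyReq(extend_input_len=10_000, max_new_tokens=512)
while req is not None:  # three steps: 4096 + 4096 + 1808 tokens
    req = add_chunked(req, rem_chunk_tokens=4096)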
@@ -190,7 +216,81 @@ def _lock_node(self, last_node: TreeNode):
             delta = self.tree_cache.dec_lock_ref(last_node)
             self.rem_total_tokens += delta
 
+    def add_one_req_ignore_eos(self, req: Req):
+        def get_req_state(r):
+            new_token_ratio = (
+                1.0 if r.sampling_params.ignore_eos else self.new_token_ratio
+            )
+            tokens_left = r.sampling_params.max_new_tokens * new_token_ratio - len(
+                r.output_ids
+            )
+            tokens_occupied = len(r.origin_input_ids) + len(r.output_ids)
+
+            if tokens_left > 0:
+                return (tokens_left, tokens_occupied)
+
+            return None
+
+        if self.req_states is None:
+            self.req_states = []
+            if self.running_batch is not None:
+                for r in self.running_batch.reqs:
+                    state = get_req_state(r)
+                    if state is not None:
+                        self.req_states.append(state)
+            for r in self.can_run_list:
+                state = get_req_state(r)
+                if state is not None:
+                    self.req_states.append(state)
+            state = get_req_state(req)
+            if state is not None:
+                self.req_states.append(state)
+
+            self.req_states.sort(key=lambda x: x[0])
+        else:
+            state = get_req_state(req)
+            if state is not None:
+                for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
+                    if tokens_left >= state[0]:
+                        self.req_states.insert(i, state)
+                        break
+                else:
+                    self.req_states.append(state)
+
+        tokens_freed = 0
+        for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
+            decode_steps = (
+                self.req_states[i + 1][0]
+                if i + 1 < len(self.req_states)
+                else tokens_left
+            )
+            bs = len(self.req_states) - i
+            if self.total_tokens + tokens_freed - decode_steps * bs <= 0:
+                return False
+            tokens_freed += tokens_occupied
+
+        if req.extend_input_len <= self.rem_chunk_tokens:
+            self.can_run_list.append(req)
+            self._prefill_one_req(
+                0,
+                req.extend_input_len,
+                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS),
+            )
+        else:
+            # Chunked prefill
+            trunc_len = self.rem_chunk_tokens
+            req.extend_input_len = trunc_len
+            req.fill_ids = req.fill_ids[:trunc_len]
+            self.can_run_list.append(req)
+            self.new_inflight_req = req
+            self._prefill_one_req(0, trunc_len, 0)
+
+        return True
+
     def add_one_req(self, req: Req):
+        if req.sampling_params.ignore_eos and self.tree_cache.disable:
+            return self.add_one_req_ignore_eos(req)
+
         total_tokens = req.extend_input_len + min(
             req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS
         )
@@ -233,4 +333,4 @@ def add_one_req(self, req: Req):
             self.tree_cache.inc_lock_ref(req.last_node)
             self._prefill_one_req(prefix_len, trunc_len, 0)
 
-        return True
+        return True and not self.no_remaining_tokens()
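The sweep at the heart of `add_one_req_ignore_eos` deserves unpacking. With `ignore_eos` (the common case in benchmarks), a request never stops early, so it will decode exactly `max_new_tokens - len(output_ids)` more tokens and future KV-cache demand becomes predictable; non-`ignore_eos` requests are discounted by `new_token_ratio`. The method keeps `(tokens_left, tokens_occupied)` pairs sorted ascending by `tokens_left` (inserting new candidates into the sorted list on later calls within the same pass) and simulates the future: between consecutive finish points, every still-running request consumes one KV slot per decode step, and each finishing request frees the slots it occupies; if the pool balance would ever drop to zero or below, admission is refused. Here is the same sweep detached from the class as a self-contained function, with an invented toy example:

def admission_feasible(total_tokens, req_states):
    """Worst-case KV-memory simulation, mirroring add_one_req_ignore_eos.

    req_states: list of (tokens_left, tokens_occupied) pairs, one per
    running or already-admitted request, sorted ascending by tokens_left.
    Assumes every request decodes until its token budget is exhausted.
    """
    tokens_freed = 0
    for i, (tokens_left, tokens_occupied) in enumerate(req_states):
        # Upper bound on decode steps until the next request finishes;
        # during that window, `bs` requests each write one token per step.
        decode_steps = (
            req_states[i + 1][0] if i + 1 < len(req_states) else tokens_left
        )
        bs = len(req_states) - i
        if total_tokens + tokens_freed - decode_steps * bs <= 0:
            return False  # the KV pool would be exhausted
        tokens_freed += tokens_occupied  # request i finishes, freeing its slots
    return True


# Three requests with 10/20/30 decode steps left, each occupying 100 slots.
print(admission_feasible(50, [(10, 100), (20, 100), (30, 100)]))    # False
print(admission_feasible(2000, [(10, 100), (20, 100), (30, 100)]))  # True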
diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py
index 8fc03b85991..d914a71c27a 100644
--- a/python/sglang/srt/managers/tp_worker.py
+++ b/python/sglang/srt/managers/tp_worker.py
@@ -221,6 +221,7 @@ def __init__(
         )
         self.new_token_ratio = self.min_new_token_ratio
         self.new_token_ratio_decay = global_config.new_token_ratio_decay
+        self.do_not_get_new_batch = False
 
     def exposed_step(self, recv_reqs: List):
         try:
@@ -253,7 +254,13 @@ def exposed_step(self, recv_reqs: List):
 
     @torch.inference_mode()
     def forward_step(self):
-        new_batch = self.get_new_prefill_batch()
+        if self.current_inflight_req is not None:
+            self.do_not_get_new_batch = False
+
+        new_batch = (
+            self.get_new_prefill_batch() if not self.do_not_get_new_batch else None
+        )
+        self.do_not_get_new_batch = False
 
         if new_batch is not None:
             # Run a new prefill batch
@@ -409,6 +416,8 @@ def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
 
         adder = PrefillAdder(
             self.tree_cache,
+            self.running_batch,
+            self.new_token_ratio,
             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
             self.max_prefill_tokens,
             self.chunked_prefill_size,
@@ -416,7 +425,7 @@ def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
         )
 
         if self.running_batch is not None:
-            adder.remove_running_tokens(self.running_batch, self.new_token_ratio)
+            adder.remove_running_tokens(self.running_batch)
 
         has_inflight = self.current_inflight_req is not None
         if self.current_inflight_req is not None:
@@ -428,11 +437,12 @@ def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
             )
 
         for req in self.waiting_queue:
+            if adder.no_remaining_tokens():
+                break
             req.init_next_round_input(None if prefix_computed else self.tree_cache)
             res = adder.add_one_req(req)
             if (
                 not res
-                or adder.no_remaining_tokens()
                 or running_bs + len(adder.can_run_list) >= self.max_running_requests
             ):
                 break
@@ -700,6 +710,7 @@ def forward_decode_batch(self, batch: ScheduleBatch):
         next_token_ids = next_token_ids.tolist()
 
         # Check finish condition
+        has_finished = False
        for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
             req.completion_tokens_wo_jump_forward += 1
             req.output_ids.append(next_token_id)
@@ -712,6 +723,7 @@ def forward_decode_batch(self, batch: ScheduleBatch):
 
             if req.finished():
                 self.tree_cache.cache_finished_req(req)
+                has_finished = True
 
             if req.return_logprob:
                 req.output_token_logprobs.append(
@@ -720,6 +732,9 @@ def forward_decode_batch(self, batch: ScheduleBatch):
             if req.top_logprobs_num > 0:
                 req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
 
+        if not has_finished:
+            self.do_not_get_new_batch = True
+
         self.handle_finished_requests(batch)
 
     def handle_finished_requests(self, batch: ScheduleBatch):
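On the tp_worker.py side, the patch threads a small decode-prioritization flag through the event loop: if a decode step completes and no request in the batch finished, no KV-cache memory was released, so the scheduling pass over the waiting queue (`get_new_prefill_batch`) cannot admit anything new and is skipped on the next step. The flag is overridden while a chunked prefill is inflight, since that request must be resumed regardless. The toy trace below shows the effect; `FakeScheduler` is an invented stand-in for the real `ModelTpServer`, not its API.

class FakeScheduler:
    """Invented stand-in with just enough state to trace the flag's effect."""

    def __init__(self):
        self.decode_steps_left = 3  # one running request, 3 decode steps to go
        self.schedule_calls = 0

    def has_inflight_req(self):
        return False  # no chunked prefill in progress in this trace

    def get_new_prefill_batch(self):
        self.schedule_calls += 1  # the comparatively expensive scheduling pass
        return None  # nothing is admissible while memory is still held

    def run_decode(self):
        self.decode_steps_left -= 1
        return self.decode_steps_left == 0  # True once the request finishes


sched = FakeScheduler()
do_not_get_new_batch = False
for _ in range(3):  # three forward_step iterations
    if sched.has_inflight_req():
        do_not_get_new_batch = False
    new_batch = None if do_not_get_new_batch else sched.get_new_prefill_batch()
    do_not_get_new_batch = False
    if new_batch is None:
        if not sched.run_decode():
            # Nothing finished, so no KV memory was freed: the next
            # scheduling pass would be futile. Skip it.
            do_not_get_new_batch = True

print(sched.schedule_calls)  # 1 instead of 3: two futile passes skipped

Together with the `has_finished` bookkeeping added to `forward_decode_batch`, this keeps consecutive decode steps back-to-back on the GPU instead of interleaving them with scheduling work that cannot succeed.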