From 9cb25830bf61a86fbf4bf4a4de801ecea4a38d31 Mon Sep 17 00:00:00 2001 From: grimoire Date: Wed, 18 Sep 2024 18:14:49 +0800 Subject: [PATCH] adjust schedule --- lmdeploy/pytorch/engine/engine.py | 52 +++++++++++++++---------------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index b9919efb3b..9bc458806c 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -43,11 +43,6 @@ def _raise_exception_on_finish(task: asyncio.Task) -> None: raise e -class NoRunningSeqs(Exception): - """NoRunningSeqs.""" - pass - - @dataclass class InferOutput: """The output of the model inference.""" @@ -636,7 +631,11 @@ async def __long_context_single_forward(inputs): ret['logits'] = ret['logits'][:, last_token_loc] return ret else: - return await __long_context_single_forward(inputs) + ret = await __long_context_single_forward(inputs) + if not return_logits and not inputs.is_decoding: + last_token_loc = [-1] + ret['logits'] = ret['logits'][:, last_token_loc] + return ret def _make_infer_outputs(self, next_token_ids: torch.LongTensor, logits: torch.Tensor, stopped: torch.Tensor): @@ -828,16 +827,15 @@ def __need_logits(seqs: SeqList): return any(seq.return_logits for seq in seqs) while True: - is_prefill = await in_que.get() + is_prefill, scheduler_output = await in_que.get() try: + running = scheduler_output.running + adapters = scheduler_output.adapters + swap_in_map = scheduler_output.swap_in_map + swap_out_map = scheduler_output.swap_out_map prefill_interval = self.scheduler_config.prefill_interval - schedule_output = self.scheduler.schedule( - is_prefill=is_prefill, prealloc_size=prefill_interval) - running: SeqList = schedule_output.running - adapters = schedule_output.adapters loop_count = 1 if is_prefill else (prefill_interval - 1) - if len(running) == 0: - raise NoRunningSeqs() + assert len(running) > 0 # create inputs inputs = self.create_model_inputs(running, adapters, @@ -855,8 +853,8 @@ def __need_logits(seqs: SeqList): await self._async_step_background( inputs=inputs, - swap_in_map=schedule_output.swap_in_map, - swap_out_map=schedule_output.swap_out_map, + swap_in_map=swap_in_map, + swap_out_map=swap_out_map, all_ids=all_ids, guided_input_ids=guided_input_ids, sampling_inputs=sampling_inputs, @@ -877,6 +875,7 @@ async def _async_loop(self): Each engine instance would communicate with the engine by queue. """ + prefill_interval = self.scheduler_config.prefill_interval in_que = asyncio.Queue() out_que = asyncio.Queue() loop_background = asyncio.get_event_loop().create_task( @@ -899,9 +898,18 @@ def __send_resps(step_outputs: Dict[int, InferOutput]): for out in step_outputs.values(): __send_resp(out) - async def __step(prefill: bool): + async def __step(): """step decoding.""" - in_que.put_nowait(prefill) + prefill = self.scheduler.has_waiting() + schedule_output = self.scheduler.schedule( + is_prefill=prefill, prealloc_size=prefill_interval) + # schedule decoding if no valid prefill reqs. + if prefill and len(schedule_output.running) == 0: + prefill = False + schedule_output = self.scheduler.schedule( + is_prefill=prefill, prealloc_size=prefill_interval) + + in_que.put_nowait((prefill, schedule_output)) finish = False while not finish: if self.req_manager.has_requests(): @@ -914,8 +922,6 @@ async def __step(prefill: bool): step_outputs = self._make_infer_outputs( next_token_ids, logits, stopped) __send_resps(step_outputs) - except NoRunningSeqs: - break except Exception as e: raise e finally: @@ -929,13 +935,7 @@ async def __step(prefill: bool): await asyncio.sleep(0.01) continue - # prefill - if self.scheduler.has_waiting(): - await __step(True) - - # decoding - if self.scheduler.has_running(): - await __step(False) + await __step() async def async_loop(self): device_manager = get_device_manager()