adjust schedule to improve TTFT in pytorch engine #2477
@@ -43,11 +43,6 @@ def _raise_exception_on_finish(task: asyncio.Task) -> None:
         raise e


-class NoRunningSeqs(Exception):
-    """NoRunningSeqs."""
-    pass
-
-
 @dataclass
 class InferOutput:
     """The output of the model inference."""
@@ -636,7 +631,11 @@ async def __long_context_single_forward(inputs):
             ret['logits'] = ret['logits'][:, last_token_loc]
             return ret
         else:
-            return await __long_context_single_forward(inputs)
+            ret = await __long_context_single_forward(inputs)
+            if not return_logits and not inputs.is_decoding:
+                last_token_loc = [-1]
Review comment on `last_token_loc = [-1]`: Does it influence the following computation?
Reply: __long_context_single_forward is used for very long prefill. Only the last output is required.
+                ret['logits'] = ret['logits'][:, last_token_loc]
+            return ret

     def _make_infer_outputs(self, next_token_ids: torch.LongTensor,
                             logits: torch.Tensor, stopped: torch.Tensor):
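A note on the exchange above: a minimal sketch of why keeping only the last position's logits is enough for a long prefill. The tensor shapes and names (`batch`, `seq_len`, `vocab`) are illustrative assumptions, not values from the engine; the point is that the first generated token is sampled from the logits of the final prompt position, so the rest of the prefill logits can be discarded.

```python
import torch

# Illustrative shapes only (not the engine's real dimensions).
batch, seq_len, vocab = 1, 16, 32

# A prefill forward pass yields logits for every prompt position.
logits = torch.randn(batch, seq_len, vocab)

# Sampling the first new token only needs the last prompt position,
# mirroring `last_token_loc = [-1]` in the hunk above.
last_token_loc = [-1]
last_logits = logits[:, last_token_loc]

assert last_logits.shape == (batch, 1, vocab)
```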
@@ -828,16 +827,15 @@ def __need_logits(seqs: SeqList):
             return any(seq.return_logits for seq in seqs)

         while True:
-            is_prefill = await in_que.get()
+            is_prefill, scheduler_output = await in_que.get()
             try:
+                running = scheduler_output.running
+                adapters = scheduler_output.adapters
+                swap_in_map = scheduler_output.swap_in_map
+                swap_out_map = scheduler_output.swap_out_map
                 prefill_interval = self.scheduler_config.prefill_interval
-                schedule_output = self.scheduler.schedule(
-                    is_prefill=is_prefill, prealloc_size=prefill_interval)
-                running: SeqList = schedule_output.running
-                adapters = schedule_output.adapters
                 loop_count = 1 if is_prefill else (prefill_interval - 1)
-                if len(running) == 0:
-                    raise NoRunningSeqs()
+                assert len(running) > 0
Review comment on `assert len(running) > 0`: how do we make sure it is True?
Reply: (lmdeploy/lmdeploy/pytorch/engine/engine.py, Lines 934 to 936 in 9cb2583) Empty requests would be skipped.

                 # create inputs
                 inputs = self.create_model_inputs(running, adapters,
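Regarding the assertion question above, a rough sketch of the producer/consumer contract that makes a plain `assert` safe. The `ScheduleOutput` dataclass, the queue wiring, and the sentinel here are simplified stand-ins for illustration, not lmdeploy's actual objects; the idea is that the side that schedules only enqueues batches that actually have running sequences, so the forward loop no longer needs the `NoRunningSeqs` exception.

```python
import asyncio
from dataclasses import dataclass, field
from typing import List, Optional, Tuple


@dataclass
class ScheduleOutput:
    """Simplified stand-in for the scheduler's output."""
    running: List[str] = field(default_factory=list)


async def producer(in_que: asyncio.Queue, outputs: List[ScheduleOutput]) -> None:
    for out in outputs:
        # Skip empty schedules instead of enqueueing them, so the consumer
        # never receives a batch without running sequences.
        if not out.running:
            continue
        in_que.put_nowait((True, out))
    in_que.put_nowait(None)  # sentinel to stop the consumer


async def consumer(in_que: asyncio.Queue) -> None:
    while True:
        item: Optional[Tuple[bool, ScheduleOutput]] = await in_que.get()
        if item is None:
            break
        is_prefill, schedule_output = item
        running = schedule_output.running
        # The producer guarantees non-empty batches, so an assert suffices.
        assert len(running) > 0
        print(f"prefill={is_prefill}, running={running}")


async def main() -> None:
    in_que: asyncio.Queue = asyncio.Queue()
    outputs = [ScheduleOutput(["s0"]), ScheduleOutput([]), ScheduleOutput(["s1", "s2"])]
    await asyncio.gather(producer(in_que, outputs), consumer(in_que))


asyncio.run(main())
```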
@@ -855,8 +853,8 @@ def __need_logits(seqs: SeqList):

                 await self._async_step_background(
                     inputs=inputs,
-                    swap_in_map=schedule_output.swap_in_map,
-                    swap_out_map=schedule_output.swap_out_map,
+                    swap_in_map=swap_in_map,
+                    swap_out_map=swap_out_map,
                     all_ids=all_ids,
                     guided_input_ids=guided_input_ids,
                     sampling_inputs=sampling_inputs,
@@ -877,6 +875,7 @@ async def _async_loop(self):

         Each engine instance would communicate with the engine by queue.
         """
+        prefill_interval = self.scheduler_config.prefill_interval
         in_que = asyncio.Queue()
         out_que = asyncio.Queue()
         loop_background = asyncio.get_event_loop().create_task(
@@ -899,9 +898,18 @@ def __send_resps(step_outputs: Dict[int, InferOutput]):
             for out in step_outputs.values():
                 __send_resp(out)

-        async def __step(prefill: bool):
+        async def __step():
             """step decoding."""
-            in_que.put_nowait(prefill)
+            prefill = self.scheduler.has_waiting()
+            schedule_output = self.scheduler.schedule(
+                is_prefill=prefill, prealloc_size=prefill_interval)
+            # schedule decoding if no valid prefill reqs.
+            if prefill and len(schedule_output.running) == 0:
+                prefill = False
+                schedule_output = self.scheduler.schedule(
+                    is_prefill=prefill, prealloc_size=prefill_interval)
+
+            in_que.put_nowait((prefill, schedule_output))
             finish = False
             while not finish:
                 if self.req_manager.has_requests():
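The new `__step` above is where the TTFT-oriented scheduling decision happens. A toy sketch of that decision in isolation: the `ToyScheduler` class and its fields are illustrative assumptions rather than lmdeploy's real scheduler API; the behavior shown is "prefer prefill whenever requests are waiting, and fall back to scheduling decoding if the prefill schedule comes back empty", as in the diff.

```python
from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass
class ToyScheduleOutput:
    running: List[str] = field(default_factory=list)


@dataclass
class ToyScheduler:
    """Illustrative stand-in for the engine's scheduler."""
    waiting: List[str] = field(default_factory=list)
    decoding: List[str] = field(default_factory=list)
    can_admit_prefill: bool = True  # pretend resources may be exhausted

    def has_waiting(self) -> bool:
        return len(self.waiting) > 0

    def schedule(self, is_prefill: bool) -> ToyScheduleOutput:
        if is_prefill:
            return ToyScheduleOutput(list(self.waiting) if self.can_admit_prefill else [])
        return ToyScheduleOutput(list(self.decoding))


def pick_step(scheduler: ToyScheduler) -> Tuple[bool, ToyScheduleOutput]:
    """Prefer prefill; fall back to decoding if the prefill schedule is empty."""
    prefill = scheduler.has_waiting()
    schedule_output = scheduler.schedule(is_prefill=prefill)
    if prefill and len(schedule_output.running) == 0:
        prefill = False
        schedule_output = scheduler.schedule(is_prefill=prefill)
    return prefill, schedule_output


# Waiting requests and prefill can be admitted -> take a prefill step.
print(pick_step(ToyScheduler(waiting=["req0"], decoding=["seq0"])))
# Waiting requests but prefill cannot be admitted -> fall back to decoding.
print(pick_step(ToyScheduler(waiting=["req0"], decoding=["seq0"], can_admit_prefill=False)))
```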
@@ -914,8 +922,6 @@ async def __step(prefill: bool):
                     step_outputs = self._make_infer_outputs(
                         next_token_ids, logits, stopped)
                     __send_resps(step_outputs)
-                except NoRunningSeqs:
-                    break
                 except Exception as e:
                     raise e
                 finally:
@@ -929,13 +935,7 @@ async def __step(prefill: bool):
                     await asyncio.sleep(0.01)
                     continue

-            # prefill
-            if self.scheduler.has_waiting():
-                await __step(True)
-
-            # decoding
-            if self.scheduler.has_running():
-                await __step(False)
+            await __step()

     async def async_loop(self):
         device_manager = get_device_manager()
Review comment on the overall change: what benefits can we get from this change?
Reply: (lmdeploy/lmdeploy/pytorch/engine/engine.py, Lines 518 to 525 in 9cb2583)