adjust schedule to improve TTFT in pytorch engine #2477

Merged: 1 commit, merged on Sep 20, 2024
52 changes: 26 additions & 26 deletions lmdeploy/pytorch/engine/engine.py
@@ -43,11 +43,6 @@ def _raise_exception_on_finish(task: asyncio.Task) -> None:
         raise e


-class NoRunningSeqs(Exception):
-    """NoRunningSeqs."""
-    pass
-
-
 @dataclass
 class InferOutput:
     """The output of the model inference."""
@@ -636,7 +631,11 @@ async def __long_context_single_forward(inputs):
                 ret['logits'] = ret['logits'][:, last_token_loc]
             return ret
         else:
-            return await __long_context_single_forward(inputs)
+            ret = await __long_context_single_forward(inputs)
+            if not return_logits and not inputs.is_decoding:
Collaborator:

what benefits can we get from this change?

Collaborator (Author):

def __get_last_logits():
    """get last logits."""
    seq_length = inputs.seq_length
    if len(seq_length) == logits.size(0):
        return logits
    last_idx = seq_length.cumsum(-1) - 1
    return logits[last_idx, :]

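For illustration only (not part of the review thread): __get_last_logits picks out the logits at each sequence's final prompt position, so trimming the long-context prefill output to the last token up front means downstream code handles one row per sequence instead of a full prompt-length logits tensor. A minimal sketch, with arbitrary toy shapes and a made-up vocab size:

import torch

vocab_size = 32                                 # arbitrary toy vocabulary
seq_length = torch.tensor([5, 3, 4])            # three packed sequences
logits = torch.randn(int(seq_length.sum()), vocab_size)   # [12, vocab]

# Same selection as __get_last_logits: index of each sequence's last token.
last_idx = seq_length.cumsum(-1) - 1            # tensor([ 4,  7, 11])
last_logits = logits[last_idx, :]               # [3, vocab]

# The long-context path handles a single sequence, so keeping only the final
# position (last_token_loc = [-1]) is equivalent and far smaller than keeping
# logits for every prompt token.
single_logits = torch.randn(1, 4096, vocab_size)
trimmed = single_logits[:, [-1]]                # [1, 1, vocab]
print(last_logits.shape, trimmed.shape)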
+                last_token_loc = [-1]
Collaborator:

Does it influence the following computation?

Collaborator (Author):

__long_context_single_forward is used for very long prefill. Only the last output is required.
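To make that concrete, here is a conceptual sketch of a chunked single-sequence prefill (not lmdeploy's actual implementation; model_chunk_forward and the chunk size are invented stand-ins). Only the final chunk's last position is needed to sample the first generated token:

import torch

def long_prefill_last_logits(model_chunk_forward, input_ids, chunk_size=2048):
    """Prefill one very long sequence chunk by chunk (conceptual sketch).

    model_chunk_forward(chunk) is assumed to consume one chunk of token ids,
    update the KV cache internally, and return logits shaped
    [1, chunk_len, vocab]. Earlier chunks' logits are discarded.
    """
    logits = None
    for start in range(0, input_ids.size(-1), chunk_size):
        chunk = input_ids[:, start:start + chunk_size]
        logits = model_chunk_forward(chunk)
    return logits[:, [-1]]  # only the last token's logits survive

# Toy usage: a fake "model" that returns random logits per chunk.
fake_model = lambda chunk: torch.randn(1, chunk.size(-1), 32)
ids = torch.randint(0, 32, (1, 5000))
print(long_prefill_last_logits(fake_model, ids).shape)  # torch.Size([1, 1, 32])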

+                ret['logits'] = ret['logits'][:, last_token_loc]
+            return ret

     def _make_infer_outputs(self, next_token_ids: torch.LongTensor,
                             logits: torch.Tensor, stopped: torch.Tensor):
@@ -828,16 +827,15 @@ def __need_logits(seqs: SeqList):
             return any(seq.return_logits for seq in seqs)

         while True:
-            is_prefill = await in_que.get()
+            is_prefill, scheduler_output = await in_que.get()
             try:
+                running = scheduler_output.running
+                adapters = scheduler_output.adapters
+                swap_in_map = scheduler_output.swap_in_map
+                swap_out_map = scheduler_output.swap_out_map
                 prefill_interval = self.scheduler_config.prefill_interval
-                schedule_output = self.scheduler.schedule(
-                    is_prefill=is_prefill, prealloc_size=prefill_interval)
-                running: SeqList = schedule_output.running
-                adapters = schedule_output.adapters
                 loop_count = 1 if is_prefill else (prefill_interval - 1)
-                if len(running) == 0:
-                    raise NoRunningSeqs()
+                assert len(running) > 0
Collaborator:

how do we make sure it is True?

Collaborator (Author):

if not self.scheduler.has_unfinished():
    await asyncio.sleep(0.01)
    continue

Empty requests would be skipped.
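My reading of this exchange (not a quote from the thread): the producer side only calls __step() when the scheduler still has unfinished sequences, and __step() falls back to a decode schedule when the prefill schedule comes back empty, so the batch handed to the background loop is never empty. A simplified sketch of that guard, with scheduler and step as stand-ins for the real objects:

import asyncio

async def produce_steps(scheduler, step):
    """Simplified producer loop: never hand an empty batch to the worker."""
    while True:
        if not scheduler.has_unfinished():
            # Nothing waiting and nothing running: skip this round entirely
            # instead of scheduling an "empty request".
            await asyncio.sleep(0.01)
            continue
        await step()  # schedules prefill or decode and enqueues the batch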


                 # create inputs
                 inputs = self.create_model_inputs(running, adapters,
@@ -855,8 +853,8 @@

                 await self._async_step_background(
                     inputs=inputs,
-                    swap_in_map=schedule_output.swap_in_map,
-                    swap_out_map=schedule_output.swap_out_map,
+                    swap_in_map=swap_in_map,
+                    swap_out_map=swap_out_map,
                     all_ids=all_ids,
                     guided_input_ids=guided_input_ids,
                     sampling_inputs=sampling_inputs,
@@ -877,6 +875,7 @@ async def _async_loop(self):

         Each engine instance would communicate with the engine by queue.
         """
+        prefill_interval = self.scheduler_config.prefill_interval
         in_que = asyncio.Queue()
         out_que = asyncio.Queue()
         loop_background = asyncio.get_event_loop().create_task(
@@ -899,9 +898,18 @@ def __send_resps(step_outputs: Dict[int, InferOutput]):
             for out in step_outputs.values():
                 __send_resp(out)

-        async def __step(prefill: bool):
+        async def __step():
             """step decoding."""
-            in_que.put_nowait(prefill)
+            prefill = self.scheduler.has_waiting()
+            schedule_output = self.scheduler.schedule(
+                is_prefill=prefill, prealloc_size=prefill_interval)
+            # schedule decoding if no valid prefill reqs.
+            if prefill and len(schedule_output.running) == 0:
+                prefill = False
+                schedule_output = self.scheduler.schedule(
+                    is_prefill=prefill, prealloc_size=prefill_interval)
+
+            in_que.put_nowait((prefill, schedule_output))
             finish = False
             while not finish:
                 if self.req_manager.has_requests():
@@ -914,8 +922,6 @@ async def __step(prefill: bool):
                     step_outputs = self._make_infer_outputs(
                         next_token_ids, logits, stopped)
                     __send_resps(step_outputs)
-                except NoRunningSeqs:
-                    break
                 except Exception as e:
                     raise e
                 finally:
@@ -929,13 +935,7 @@
                 await asyncio.sleep(0.01)
                 continue

-            # prefill
-            if self.scheduler.has_waiting():
-                await __step(True)
-
-            # decoding
-            if self.scheduler.has_running():
-                await __step(False)
+            await __step()

     async def async_loop(self):
         device_manager = get_device_manager()
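Taken together, scheduling now happens in __step() on the main loop side: prefer a prefill schedule whenever requests are waiting, and fall back to decoding otherwise, so newly arrived requests get their first forward pass (and first token) sooner. Below is a self-contained sketch of that policy; ToyScheduler, ScheduleOutput, and the request bookkeeping are invented for illustration and are not the engine's real API:

from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass
class ScheduleOutput:
    """Simplified stand-in for the scheduler's output."""
    running: List[str] = field(default_factory=list)


class ToyScheduler:
    """Toy scheduler, just enough to show the prefill-first policy."""

    def __init__(self) -> None:
        self.waiting: List[str] = []   # requests that still need prefill
        self.decoding: List[str] = []  # requests already generating tokens

    def has_waiting(self) -> bool:
        return len(self.waiting) > 0

    def schedule(self, is_prefill: bool) -> ScheduleOutput:
        if is_prefill:
            # A real scheduler may return an empty batch here (e.g. no free
            # cache blocks); that case is what the fallback below handles.
            batch, self.waiting = self.waiting, []
            self.decoding.extend(batch)
            return ScheduleOutput(running=batch)
        return ScheduleOutput(running=list(self.decoding))


def step(scheduler: ToyScheduler) -> Tuple[bool, ScheduleOutput]:
    """Prefill-first scheduling, mirroring the shape of the new __step."""
    prefill = scheduler.has_waiting()
    output = scheduler.schedule(is_prefill=prefill)
    # Fall back to decoding if there was no valid prefill batch.
    if prefill and len(output.running) == 0:
        prefill = False
        output = scheduler.schedule(is_prefill=prefill)
    return prefill, output


if __name__ == '__main__':
    sched = ToyScheduler()
    sched.decoding = ['req-0']   # one request already generating
    sched.waiting = ['req-1']    # a new request just arrived
    print(step(sched))           # prefill step for the new request first
    print(step(sched))           # nothing left waiting, decode step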