adjust schedule
grimoire committed Sep 18, 2024
1 parent 1678dc5 commit 9cb2583
Showing 1 changed file with 26 additions and 26 deletions.
52 changes: 26 additions & 26 deletions lmdeploy/pytorch/engine/engine.py
@@ -43,11 +43,6 @@ def _raise_exception_on_finish(task: asyncio.Task) -> None:
raise e


class NoRunningSeqs(Exception):
"""NoRunningSeqs."""
pass


@dataclass
class InferOutput:
"""The output of the model inference."""
@@ -636,7 +631,11 @@ async def __long_context_single_forward(inputs):
ret['logits'] = ret['logits'][:, last_token_loc]
return ret
else:
return await __long_context_single_forward(inputs)
ret = await __long_context_single_forward(inputs)
if not return_logits and not inputs.is_decoding:
last_token_loc = [-1]
ret['logits'] = ret['logits'][:, last_token_loc]
return ret

def _make_infer_outputs(self, next_token_ids: torch.LongTensor,
logits: torch.Tensor, stopped: torch.Tensor):
@@ -828,16 +827,15 @@ def __need_logits(seqs: SeqList):
return any(seq.return_logits for seq in seqs)

while True:
is_prefill = await in_que.get()
is_prefill, scheduler_output = await in_que.get()
try:
running = scheduler_output.running
adapters = scheduler_output.adapters
swap_in_map = scheduler_output.swap_in_map
swap_out_map = scheduler_output.swap_out_map
prefill_interval = self.scheduler_config.prefill_interval
schedule_output = self.scheduler.schedule(
is_prefill=is_prefill, prealloc_size=prefill_interval)
running: SeqList = schedule_output.running
adapters = schedule_output.adapters
loop_count = 1 if is_prefill else (prefill_interval - 1)
if len(running) == 0:
raise NoRunningSeqs()
assert len(running) > 0

# create inputs
inputs = self.create_model_inputs(running, adapters,
@@ -855,8 +853,8 @@

await self._async_step_background(
inputs=inputs,
swap_in_map=schedule_output.swap_in_map,
swap_out_map=schedule_output.swap_out_map,
swap_in_map=swap_in_map,
swap_out_map=swap_out_map,
all_ids=all_ids,
guided_input_ids=guided_input_ids,
sampling_inputs=sampling_inputs,
@@ -877,6 +875,7 @@ async def _async_loop(self):
Each engine instance would communicate with the engine by queue.
"""
prefill_interval = self.scheduler_config.prefill_interval
in_que = asyncio.Queue()
out_que = asyncio.Queue()
loop_background = asyncio.get_event_loop().create_task(
@@ -899,9 +898,18 @@ def __send_resps(step_outputs: Dict[int, InferOutput]):
for out in step_outputs.values():
__send_resp(out)

async def __step(prefill: bool):
async def __step():
"""step decoding."""
in_que.put_nowait(prefill)
prefill = self.scheduler.has_waiting()
schedule_output = self.scheduler.schedule(
is_prefill=prefill, prealloc_size=prefill_interval)
# schedule decoding if no valid prefill reqs.
if prefill and len(schedule_output.running) == 0:
prefill = False
schedule_output = self.scheduler.schedule(
is_prefill=prefill, prealloc_size=prefill_interval)

in_que.put_nowait((prefill, schedule_output))
finish = False
while not finish:
if self.req_manager.has_requests():
@@ -914,8 +922,6 @@ async def __step(prefill: bool):
step_outputs = self._make_infer_outputs(
next_token_ids, logits, stopped)
__send_resps(step_outputs)
except NoRunningSeqs:
break
except Exception as e:
raise e
finally:
@@ -929,13 +935,7 @@ async def __step(prefill: bool):
await asyncio.sleep(0.01)
continue

# prefill
if self.scheduler.has_waiting():
await __step(True)

# decoding
if self.scheduler.has_running():
await __step(False)
await __step()

async def async_loop(self):
device_manager = get_device_manager()
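
As a side note on the second hunk above: the non-chunked long-context path now trims prefill logits the same way the chunked path does, keeping only the last token's logits when the caller did not request full logits. A tiny sketch of that slicing, using toy shapes and flag names rather than lmdeploy's real tensors:

import torch

# toy shapes only; real logits come from the model forward
batch, seq_len, vocab = 1, 8, 32000
logits = torch.randn(batch, seq_len, vocab)

return_logits = False   # caller did not request full logits
is_decoding = False     # this is a prefill step

if not return_logits and not is_decoding:
    last_token_loc = [-1]                  # last position of the single long-context sequence
    logits = logits[:, last_token_loc]     # -> shape (1, 1, vocab)

print(logits.shape)  # torch.Size([1, 1, 32000])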

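The main behavioural change in this commit is that scheduling moves out of the background loop: __step() now calls scheduler.schedule() itself, falling back to a decoding schedule when a prefill schedule yields no runnable sequences, and pushes (is_prefill, schedule_output) into the queue, so the background loop simply unpacks the result and asserts it is non-empty instead of raising NoRunningSeqs. Below is a minimal, self-contained sketch (not part of the commit) of that producer/consumer flow; DummyScheduler and ScheduleOutput are simplified stand-ins for the real lmdeploy types.

import asyncio
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class ScheduleOutput:
    """Simplified stand-in for the scheduler output used in the diff."""
    running: List[str] = field(default_factory=list)
    swap_in_map: Dict[int, int] = field(default_factory=dict)
    swap_out_map: Dict[int, int] = field(default_factory=dict)


class DummyScheduler:
    """Toy scheduler: one waiting sequence that gets admitted on prefill."""

    def __init__(self):
        self.waiting = ['seq-0']
        self.running: List[str] = []

    def has_waiting(self) -> bool:
        return bool(self.waiting)

    def schedule(self, is_prefill: bool, prealloc_size: int) -> ScheduleOutput:
        if is_prefill and self.waiting:
            self.running.extend(self.waiting)
            self.waiting.clear()
        return ScheduleOutput(running=list(self.running))


async def main():
    scheduler = DummyScheduler()
    in_que: asyncio.Queue = asyncio.Queue()
    prefill_interval = 16

    async def background_loop():
        # consumer: mirrors the updated loop body, which unpacks the
        # pre-computed schedule instead of calling scheduler.schedule()
        while True:
            is_prefill, schedule_output = await in_que.get()
            assert len(schedule_output.running) > 0
            print(f'step(is_prefill={is_prefill}) on {schedule_output.running}')
            in_que.task_done()

    async def step():
        # producer: schedule first, fall back to a decoding schedule when a
        # prefill schedule produced no runnable sequences
        prefill = scheduler.has_waiting()
        schedule_output = scheduler.schedule(is_prefill=prefill,
                                             prealloc_size=prefill_interval)
        if prefill and len(schedule_output.running) == 0:
            prefill = False
            schedule_output = scheduler.schedule(is_prefill=prefill,
                                                 prealloc_size=prefill_interval)
        in_que.put_nowait((prefill, schedule_output))

    loop_task = asyncio.create_task(background_loop())
    await step()           # admits the waiting sequence (prefill step)
    await step()           # nothing waiting any more, so a decoding step
    await in_que.join()
    loop_task.cancel()


if __name__ == '__main__':
    asyncio.run(main())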