-
-
Notifications
You must be signed in to change notification settings - Fork 4.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Core] Combine async postprocessor and multi-step #7921
Changes from 5 commits
7b420d1
ea2c989
dcc4824
b1c78f1
3e2b7a8
c29e4da
e26b18b
942bc12
dbbde98
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -91,7 +91,8 @@ class SchedulerOutputState: | |
|
||
@dataclass | ||
class SchedulerContext: | ||
output_queue: Deque[Tuple[List[SamplerOutput], List[SequenceGroupMetadata], | ||
output_queue: Deque[Tuple[Optional[List[SamplerOutput]], | ||
List[SequenceGroupMetadata], | ||
SchedulerOutputs]] = field( | ||
default_factory=lambda: deque()) | ||
|
||
|
@@ -432,6 +433,13 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: | |
for v_id in range(self.parallel_config.pipeline_parallel_size) | ||
] | ||
|
||
self.async_callback_multi_step = [ | ||
functools.partial(self._process_model_outputs, | ||
virtual_engine=v_id, | ||
is_async=False) | ||
for v_id in range(self.parallel_config.pipeline_parallel_size) | ||
] | ||
|
||
def _initialize_kv_caches(self) -> None: | ||
"""Initialize the KV cache in the worker(s). | ||
|
||
|
@@ -1240,8 +1248,11 @@ def _process_sequence_group_outputs( | |
|
||
return | ||
|
||
def _process_model_outputs(self, virtual_engine: int, | ||
is_async: bool) -> None: | ||
def _process_model_outputs(self, | ||
virtual_engine: int, | ||
is_async: bool, | ||
sampler_output: Optional[SamplerOutput] = None, | ||
is_last_output: bool = False) -> None: | ||
"""Apply the model output to the sequences in the scheduled seq groups. | ||
|
||
virtual_engine: The engine id to operate on | ||
|
@@ -1255,13 +1266,25 @@ def _process_model_outputs(self, virtual_engine: int, | |
""" | ||
now = time.time() | ||
|
||
is_multi_step = sampler_output is not None | ||
|
||
ctx: SchedulerContext = self.scheduler_contexts[virtual_engine] | ||
|
||
if len(ctx.output_queue) == 0: | ||
return None | ||
|
||
(outputs, seq_group_metadata_list, | ||
scheduler_outputs) = ctx.output_queue.popleft() | ||
if is_multi_step: | ||
# Async + multi-step case | ||
(outputs, seq_group_metadata_list, | ||
scheduler_outputs) = ctx.output_queue[0] | ||
assert outputs is None | ||
outputs = [sampler_output] | ||
else: | ||
# Async standard case | ||
(outputs, seq_group_metadata_list, | ||
scheduler_outputs) = ctx.output_queue.popleft() | ||
|
||
assert outputs is not None | ||
|
||
# Sanity check | ||
assert len(seq_group_metadata_list) == len( | ||
|
@@ -1320,15 +1343,19 @@ def _process_model_outputs(self, virtual_engine: int, | |
self.output_processor.process_outputs(seq_group, output, | ||
is_async) | ||
|
||
# Free the finished sequence groups. | ||
# For async + multi-step, free finished seqs and create outputs | ||
# only on the final step. | ||
if is_multi_step and not is_last_output: | ||
return | ||
|
||
for scheduler in self.scheduler: | ||
scheduler.free_finished_seq_groups() | ||
|
||
# Create the outputs. | ||
for i, _ in enumerate(seq_group_metadata_list): | ||
scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i] | ||
|
||
if i in finished_before: | ||
if not is_multi_step and i in finished_before: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why add There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually not since multi-step is using is_last_output to indicate if to run this code or not. Added docstring below to explain better. |
||
continue # Avoids double processing | ||
|
||
seq_group = scheduled_seq_group.seq_group | ||
|
@@ -1342,7 +1369,11 @@ def _process_model_outputs(self, virtual_engine: int, | |
request_output = RequestOutputFactory.create(seq_group) | ||
ctx.request_outputs.append(request_output) | ||
|
||
if is_async: | ||
# For async + multi-step, do stats only on the last output. | ||
# Otherwise, do stats if the execution is async | ||
do_stats = is_multi_step or is_async | ||
|
||
if do_stats: | ||
# Log stats. | ||
self.do_log_stats(scheduler_outputs, outputs, finished_before) | ||
|
||
|
@@ -1437,7 +1468,7 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: | |
"as performance will be severely degraded otherwise.") | ||
|
||
# For llm_engine, there is no pipeline parallel support, so the engine | ||
# used is always 0 | ||
# used is always 0. | ||
virtual_engine = 0 | ||
|
||
# These are cached outputs from previous iterations. None if on first | ||
|
@@ -1447,6 +1478,10 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: | |
scheduler_outputs = cached_outputs.scheduler_outputs | ||
allow_async_output_proc = cached_outputs.allow_async_output_proc | ||
|
||
# Detect async + multi-step | ||
use_async_and_multi_step = (self.scheduler_config.is_multi_step | ||
and allow_async_output_proc) | ||
|
||
ctx = self.scheduler_contexts[virtual_engine] | ||
|
||
# Skip the scheduler if there are any remaining steps in the seq groups. | ||
|
@@ -1462,11 +1497,22 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: | |
allow_async_output_proc | ||
) = self.scheduler[virtual_engine].schedule() | ||
|
||
# Detect async + multi-step | ||
use_async_and_multi_step = (self.scheduler_config.is_multi_step | ||
and allow_async_output_proc) | ||
|
||
# Maybe switch from async mode to sync mode | ||
if not allow_async_output_proc and len(ctx.output_queue) > 0: | ||
self._process_model_outputs(virtual_engine=virtual_engine, | ||
is_async=True) | ||
|
||
# For async + multi-step, init the queue | ||
if use_async_and_multi_step: | ||
assert len(ctx.output_queue) == 0 | ||
assert seq_group_metadata_list is not None | ||
ctx.output_queue.append( | ||
(None, seq_group_metadata_list, scheduler_outputs)) | ||
|
||
if (self.scheduler_config.is_multi_step | ||
and scheduler_outputs.num_lookahead_slots > 0): | ||
# cache the scheduler outputs for the next iteration if we have | ||
|
@@ -1478,9 +1524,6 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: | |
assert seq_group_metadata_list is not None | ||
assert scheduler_outputs is not None | ||
|
||
assert not (self.scheduler_config.is_multi_step and \ | ||
allow_async_output_proc) | ||
|
||
if not scheduler_outputs.is_empty(): | ||
finished_requests_ids = self.scheduler[ | ||
virtual_engine].get_and_reset_finished_requests_ids() | ||
|
@@ -1505,8 +1548,13 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: | |
last_sampled_token_ids=last_sampled_token_ids) | ||
|
||
if allow_async_output_proc: | ||
execute_model_req.async_callback = self.async_callback[ | ||
virtual_engine] | ||
async_callback = self.async_callback_multi_step[ | ||
virtual_engine] if use_async_and_multi_step \ | ||
else self.async_callback[virtual_engine] | ||
|
||
execute_model_req.async_callback = async_callback | ||
execute_model_req.use_async_and_multi_step = \ | ||
use_async_and_multi_step | ||
|
||
output = self.model_executor.execute_model( | ||
execute_model_req=execute_model_req) | ||
|
@@ -1518,7 +1566,7 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: | |
else: | ||
# Nothing scheduled => If there is pending async postprocessor, | ||
# then finish it here. | ||
if len(ctx.output_queue) > 0: | ||
if not use_async_and_multi_step and len(ctx.output_queue) > 0: | ||
assert not self.scheduler_config.is_multi_step | ||
self._process_model_outputs(virtual_engine=virtual_engine, | ||
is_async=True) | ||
|
@@ -1535,18 +1583,23 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: | |
if self.scheduler_config.is_multi_step: | ||
self.cached_scheduler_outputs[0] = SchedulerOutputState() | ||
|
||
# Add results to the output_queue | ||
# (for async or non-async postprocessing) | ||
ctx.output_queue.append( | ||
(output, seq_group_metadata_list, scheduler_outputs)) | ||
if use_async_and_multi_step: | ||
# For async + multi-step, clear the queue | ||
ctx.output_queue.clear() | ||
else: | ||
# Add results to the output_queue | ||
# (for async or non-async postprocessing) | ||
ctx.output_queue.append( | ||
(output, seq_group_metadata_list, scheduler_outputs)) | ||
|
||
if output and allow_async_output_proc: | ||
assert len(output) == 1, ("Multi step decoding does not work " | ||
"with async output processing.") | ||
if output and allow_async_output_proc: | ||
assert len(output) == 1, ( | ||
"Multi step decoding does not work " | ||
"with async output processing.") | ||
|
||
self._advance_to_next_step( | ||
output[0], seq_group_metadata_list, | ||
scheduler_outputs.scheduled_seq_groups) | ||
self._advance_to_next_step( | ||
output[0], seq_group_metadata_list, | ||
scheduler_outputs.scheduled_seq_groups) | ||
|
||
# Check if need to run the usual non-async path | ||
if not allow_async_output_proc: | ||
|
@@ -1560,7 +1613,10 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: | |
self.do_tracing(scheduler_outputs) | ||
else: | ||
# Multi-step case | ||
ctx.request_outputs = [] | ||
if use_async_and_multi_step: | ||
return [] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why is this needed? Won't There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Async postprocessor may modify ctx.request_outputs at each step, so I did not want to touch it in the middle of multi-steps running. |
||
else: | ||
ctx.request_outputs = [] | ||
|
||
if not self.has_unfinished_requests(): | ||
# Drain async postprocessor (if exists) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: docstring for
sampler_output
andis_last_output
can be useful hereThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good idea, added