Skip to content

Commit

Permalink
[Core] Combine async postprocessor and multi-step (vllm-project#7921)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexm-redhat authored Aug 29, 2024
1 parent f205c09 commit 3f60f22
Show file tree
Hide file tree
Showing 8 changed files with 215 additions and 65 deletions.
10 changes: 6 additions & 4 deletions tests/multi_step/test_correctness_async_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,12 @@ async def completions_with_server_args(prompts: List[str], model_name: str,
@pytest.mark.parametrize("eager_mode", [False, True])
@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
@pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
@pytest.mark.parametrize("is_async", [False, True])
@pytest.mark.asyncio
async def test_multi_step(example_prompts, model: str, tp_size: int,
pp_size: int, eager_mode: int,
num_scheduler_steps: int, num_prompts: int):
num_scheduler_steps: int, num_prompts: int,
is_async: bool):

prompts = example_prompts
if len(prompts) < num_prompts:
Expand All @@ -62,9 +64,9 @@ async def test_multi_step(example_prompts, model: str, tp_size: int,
ms_server_args = DEFAULT_SERVER_ARGS + \
["--num-scheduler-steps", f"{num_scheduler_steps}"]

# Disable output proc callback as its not supported
# with multi-step right now
ms_server_args += ["--disable-async-output-proc"]
if not is_async:
ms_server_args += ["--disable-async-output-proc"]

if eager_mode:
ms_server_args.append("--enforce-eager")

Expand Down
5 changes: 1 addition & 4 deletions vllm/core/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1107,10 +1107,7 @@ def schedule(
if not self.cache_config.enable_prefix_caching:
common_computed_block_nums = []

# TODO: Combine multi-step and async postprocessor
allow_async_output_proc: bool = (
self.use_async_output_proc
and not self.scheduler_config.is_multi_step)
allow_async_output_proc: bool = self.use_async_output_proc

# Create input data structures.
seq_group_metadata_list: List[SequenceGroupMetadata] = []
Expand Down
65 changes: 44 additions & 21 deletions vllm/engine/async_llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,10 @@ async def step_async(
scheduler_outputs = cached_outputs.scheduler_outputs
allow_async_output_proc = cached_outputs.allow_async_output_proc

# Detect async + multi-step
use_async_and_multi_step = (self.scheduler_config.is_multi_step
and allow_async_output_proc)

ctx = self.scheduler_contexts[virtual_engine]

# skip the scheduler if there are any remaining steps in the seq groups.
Expand All @@ -289,17 +293,27 @@ async def step_async(
# Clear outputs on scheduler iteration start
ctx.request_outputs.clear()

# Schedule iteration
(seq_group_metadata_list, scheduler_outputs,
allow_async_output_proc
) = self.scheduler[virtual_engine].schedule()

# If current scheduler iteration has no async postprocessor,
# then we need first to drain the pending async postprocessor
# before moving forward
# Detect async + multi-step
use_async_and_multi_step = (self.scheduler_config.is_multi_step
and allow_async_output_proc)

# Maybe switch from async mode to sync mode
if not allow_async_output_proc and len(ctx.output_queue) > 0:
self._process_model_outputs(virtual_engine=virtual_engine,
is_async=True)

# For async + multi-step, init the queue
if use_async_and_multi_step:
assert len(ctx.output_queue) == 0
assert seq_group_metadata_list is not None
ctx.output_queue.append(
(None, seq_group_metadata_list, scheduler_outputs))

if (self.scheduler_config.is_multi_step
and scheduler_outputs.num_lookahead_slots > 0):
# cache the scheduler outputs for the next iteration if we have
Expand All @@ -311,9 +325,6 @@ async def step_async(
assert seq_group_metadata_list is not None
assert scheduler_outputs is not None

assert not (self.scheduler_config.is_multi_step and \
allow_async_output_proc)

if not scheduler_outputs.is_empty():
finished_requests_ids = self.scheduler[
virtual_engine].get_and_reset_finished_requests_ids()
Expand All @@ -339,8 +350,13 @@ async def step_async(
last_sampled_token_ids=last_sampled_token_ids)

if allow_async_output_proc:
execute_model_req.async_callback = self.async_callback[
virtual_engine]
async_callback = self.async_callback_multi_step[
virtual_engine] if use_async_and_multi_step \
else self.async_callback[virtual_engine]

execute_model_req.async_callback = async_callback
execute_model_req.use_async_and_multi_step = \
use_async_and_multi_step

# Execute the model.
output = await self.model_executor.execute_model_async(
Expand All @@ -350,7 +366,7 @@ async def step_async(
if self.scheduler_config.is_multi_step:
self._update_cached_scheduler_output(virtual_engine, output)
else:
if len(ctx.output_queue) > 0:
if not use_async_and_multi_step and len(ctx.output_queue) > 0:
assert not self.scheduler_config.is_multi_step
self._process_model_outputs(virtual_engine=virtual_engine,
is_async=True)
Expand All @@ -362,22 +378,25 @@ async def step_async(
seq_group.finish_step()

if not self._has_remaining_steps(seq_group_metadata_list):
# clear the cache if we have finished all the steps
# Clear the cache if we have finished all the steps
if self.scheduler_config.is_multi_step:
self.cached_scheduler_outputs[
virtual_engine] = SchedulerOutputState()

# Cache results in engine
ctx.output_queue.append(
(output, seq_group_metadata_list, scheduler_outputs))
if use_async_and_multi_step:
# For async + multi-step, clear the queue
ctx.output_queue.clear()
else:
ctx.output_queue.append(
(output, seq_group_metadata_list, scheduler_outputs))

if output and allow_async_output_proc:
assert len(
output
) == 1, "Multi step decoding does not work with async output processing." # noqa: E501
self._advance_to_next_step(
output[0], seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)
if output and allow_async_output_proc:
assert len(
output
) == 1, "Multi step decoding does not work with async output processing." # noqa: E501
self._advance_to_next_step(
output[0], seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)

if not allow_async_output_proc:
self._process_model_outputs(virtual_engine=virtual_engine,
Expand All @@ -390,7 +409,11 @@ async def step_async(
self.do_tracing(scheduler_outputs)

else:
ctx.request_outputs = []
# Multi-step case
if use_async_and_multi_step:
return []
else:
ctx.request_outputs = []

if not self.has_unfinished_requests():
# Drain async postprocessor (if exists)
Expand Down
Loading

0 comments on commit 3f60f22

Please sign in to comment.