From 7b420d1a7da66d359da1e6c02d5f46b463277e88 Mon Sep 17 00:00:00 2001 From: Alexander Matveev Date: Tue, 27 Aug 2024 13:26:26 +0000 Subject: [PATCH 1/9] Combine async postprocessor and multi-step execution --- .../multi_step/test_correctness_async_llm.py | 10 +- vllm/core/scheduler.py | 5 +- vllm/engine/async_llm_engine.py | 58 +++++++---- vllm/engine/llm_engine.py | 96 ++++++++++++++----- vllm/sequence.py | 10 +- vllm/worker/model_runner.py | 9 +- vllm/worker/multi_step_model_runner.py | 79 ++++++++++++--- vllm/worker/multi_step_worker.py | 8 ++ 8 files changed, 207 insertions(+), 68 deletions(-) diff --git a/tests/multi_step/test_correctness_async_llm.py b/tests/multi_step/test_correctness_async_llm.py index ad99d70d7417c..ac04be3d9a689 100644 --- a/tests/multi_step/test_correctness_async_llm.py +++ b/tests/multi_step/test_correctness_async_llm.py @@ -47,10 +47,12 @@ async def completions_with_server_args(prompts: List[str], model_name: str, @pytest.mark.parametrize("eager_mode", [False, True]) @pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS) @pytest.mark.parametrize("num_prompts", NUM_PROMPTS) +@pytest.mark.parametrize("is_async", [False, True]) @pytest.mark.asyncio async def test_multi_step(example_prompts, model: str, tp_size: int, pp_size: int, eager_mode: int, - num_scheduler_steps: int, num_prompts: int): + num_scheduler_steps: int, num_prompts: int, + is_async: bool): prompts = example_prompts if len(prompts) < num_prompts: @@ -62,9 +64,9 @@ async def test_multi_step(example_prompts, model: str, tp_size: int, ms_server_args = DEFAULT_SERVER_ARGS + \ ["--num-scheduler-steps", f"{num_scheduler_steps}"] - # Disable output proc callback as its not supported - # with multi-step right now - ms_server_args += ["--disable-async-output-proc"] + if not is_async: + ms_server_args += ["--disable-async-output-proc"] + if eager_mode: ms_server_args.append("--enforce-eager") diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 51fde6e4eb7a3..4c2f715820317 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1107,10 +1107,7 @@ def schedule( if not self.cache_config.enable_prefix_caching: common_computed_block_nums = [] - # TODO: Combine multi-step and async postprocessor - allow_async_output_proc: bool = ( - self.use_async_output_proc - and not self.scheduler_config.is_multi_step) + allow_async_output_proc: bool = self.use_async_output_proc # Create input data structures. seq_group_metadata_list: List[SequenceGroupMetadata] = [] diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 37696bf1d9dc9..2d635be950164 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -279,6 +279,10 @@ async def step_async( scheduler_outputs = cached_outputs.scheduler_outputs allow_async_output_proc = cached_outputs.allow_async_output_proc + # Detect async + multi-step + use_async_and_multi_step = (self.scheduler_config.is_multi_step + and allow_async_output_proc) + ctx = self.scheduler_contexts[virtual_engine] # skip the scheduler if there are any remaining steps in the seq groups. 
@@ -289,17 +293,27 @@ async def step_async( # Clear outputs on scheduler iteration start ctx.request_outputs.clear() + # Schedule iteration (seq_group_metadata_list, scheduler_outputs, allow_async_output_proc ) = self.scheduler[virtual_engine].schedule() - # If current scheduler iteration has no async postprocessor, - # then we need first to drain the pending async postprocessor - # before moving forward + # Detect async + multi-step + use_async_and_multi_step = (self.scheduler_config.is_multi_step + and allow_async_output_proc) + + # Maybe switch from async mode to sync mode if not allow_async_output_proc and len(ctx.output_queue) > 0: self._process_model_outputs(virtual_engine=virtual_engine, is_async=True) + # For async + multi-step, init the queue + if use_async_and_multi_step: + assert len(ctx.output_queue) == 0 + assert seq_group_metadata_list is not None + ctx.output_queue.append( + (None, seq_group_metadata_list, scheduler_outputs)) + if (self.scheduler_config.is_multi_step and scheduler_outputs.num_lookahead_slots > 0): # cache the scheduler outputs for the next iteration if we have @@ -311,9 +325,6 @@ async def step_async( assert seq_group_metadata_list is not None assert scheduler_outputs is not None - assert not (self.scheduler_config.is_multi_step and \ - allow_async_output_proc) - if not scheduler_outputs.is_empty(): finished_requests_ids = self.scheduler[ virtual_engine].get_and_reset_finished_requests_ids() @@ -341,6 +352,8 @@ async def step_async( if allow_async_output_proc: execute_model_req.async_callback = self.async_callback[ virtual_engine] + execute_model_req.use_async_and_multi_step = \ + use_async_and_multi_step # Execute the model. output = await self.model_executor.execute_model_async( @@ -350,7 +363,7 @@ async def step_async( if self.scheduler_config.is_multi_step: self._update_cached_scheduler_output(virtual_engine, output) else: - if len(ctx.output_queue) > 0: + if not use_async_and_multi_step and len(ctx.output_queue) > 0: assert not self.scheduler_config.is_multi_step self._process_model_outputs(virtual_engine=virtual_engine, is_async=True) @@ -362,22 +375,25 @@ async def step_async( seq_group.finish_step() if not self._has_remaining_steps(seq_group_metadata_list): - # clear the cache if we have finished all the steps + # Clear the cache if we have finished all the steps if self.scheduler_config.is_multi_step: self.cached_scheduler_outputs[ virtual_engine] = SchedulerOutputState() - # Cache results in engine - ctx.output_queue.append( - (output, seq_group_metadata_list, scheduler_outputs)) + if use_async_and_multi_step: + # For async + multi-step, clear the queue + ctx.output_queue.clear() + else: + ctx.output_queue.append( + (output, seq_group_metadata_list, scheduler_outputs)) - if output and allow_async_output_proc: - assert len( - output - ) == 1, "Multi step decoding does not work with async output processing." # noqa: E501 - self._advance_to_next_step( - output[0], seq_group_metadata_list, - scheduler_outputs.scheduled_seq_groups) + if output and allow_async_output_proc: + assert len( + output + ) == 1, "Multi step decoding does not work with async output processing." 
# noqa: E501 + self._advance_to_next_step( + output[0], seq_group_metadata_list, + scheduler_outputs.scheduled_seq_groups) if not allow_async_output_proc: self._process_model_outputs(virtual_engine=virtual_engine, @@ -390,7 +406,11 @@ async def step_async( self.do_tracing(scheduler_outputs) else: - ctx.request_outputs = [] + # Multi-step case + if use_async_and_multi_step: + return [] + else: + ctx.request_outputs = [] if not self.has_unfinished_requests(): # Drain async postprocessor (if exists) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index a6de8817946cc..1c4d747aa5aff 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -39,9 +39,10 @@ from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams -from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest, - SamplerOutput, Sequence, SequenceGroup, - SequenceGroupMetadata, SequenceStatus) +from vllm.sequence import (AsyncCallbackData, EmbeddingSequenceGroupOutput, + ExecuteModelRequest, SamplerOutput, Sequence, + SequenceGroup, SequenceGroupMetadata, + SequenceStatus) from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context, init_tracer) from vllm.transformers_utils.config import try_get_generation_config @@ -1240,8 +1241,11 @@ def _process_sequence_group_outputs( return - def _process_model_outputs(self, virtual_engine: int, - is_async: bool) -> None: + def _process_model_outputs(self, + virtual_engine: int, + is_async: bool, + sampler_output: Optional[SamplerOutput] = None, + is_last_output: bool = False) -> None: """Apply the model output to the sequences in the scheduled seq groups. virtual_engine: The engine id to operate on @@ -1255,13 +1259,25 @@ def _process_model_outputs(self, virtual_engine: int, """ now = time.time() + is_multi_step = sampler_output is not None + ctx: SchedulerContext = self.scheduler_contexts[virtual_engine] if len(ctx.output_queue) == 0: return None - (outputs, seq_group_metadata_list, - scheduler_outputs) = ctx.output_queue.popleft() + if is_multi_step: + # Async + multi-step case + (outputs, seq_group_metadata_list, + scheduler_outputs) = ctx.output_queue[0] + assert outputs is None + outputs = [sampler_output] + else: + # Async standard case + (outputs, seq_group_metadata_list, + scheduler_outputs) = ctx.output_queue.popleft() + + assert outputs is not None # Sanity check assert len(seq_group_metadata_list) == len( @@ -1320,7 +1336,11 @@ def _process_model_outputs(self, virtual_engine: int, self.output_processor.process_outputs(seq_group, output, is_async) - # Free the finished sequence groups. + # For async + multi-step, free finished seqs and create outputs + # only on the final step. + if is_multi_step and not is_last_output: + return + for scheduler in self.scheduler: scheduler.free_finished_seq_groups() @@ -1328,7 +1348,7 @@ def _process_model_outputs(self, virtual_engine: int, for i, _ in enumerate(seq_group_metadata_list): scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i] - if i in finished_before: + if not is_multi_step and i in finished_before: continue # Avoids double processing seq_group = scheduled_seq_group.seq_group @@ -1342,7 +1362,11 @@ def _process_model_outputs(self, virtual_engine: int, request_output = RequestOutputFactory.create(seq_group) ctx.request_outputs.append(request_output) - if is_async: + # For async + multi-step, do stats only on the last output. 
+ # Otherwise, do stats if the execution is async + do_stats = is_multi_step or is_async + + if do_stats: # Log stats. self.do_log_stats(scheduler_outputs, outputs, finished_before) @@ -1447,6 +1471,10 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: scheduler_outputs = cached_outputs.scheduler_outputs allow_async_output_proc = cached_outputs.allow_async_output_proc + # Detect async + multi-step + use_async_and_multi_step = (self.scheduler_config.is_multi_step + and allow_async_output_proc) + ctx = self.scheduler_contexts[virtual_engine] # Skip the scheduler if there are any remaining steps in the seq groups. @@ -1462,11 +1490,22 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: allow_async_output_proc ) = self.scheduler[virtual_engine].schedule() + # Detect async + multi-step + use_async_and_multi_step = (self.scheduler_config.is_multi_step + and allow_async_output_proc) + # Maybe switch from async mode to sync mode if not allow_async_output_proc and len(ctx.output_queue) > 0: self._process_model_outputs(virtual_engine=virtual_engine, is_async=True) + # For async + multi-step, init the queue + if use_async_and_multi_step: + assert len(ctx.output_queue) == 0 + assert seq_group_metadata_list is not None + ctx.output_queue.append( + (None, seq_group_metadata_list, scheduler_outputs)) + if (self.scheduler_config.is_multi_step and scheduler_outputs.num_lookahead_slots > 0): # cache the scheduler outputs for the next iteration if we have @@ -1478,9 +1517,6 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: assert seq_group_metadata_list is not None assert scheduler_outputs is not None - assert not (self.scheduler_config.is_multi_step and \ - allow_async_output_proc) - if not scheduler_outputs.is_empty(): finished_requests_ids = self.scheduler[ virtual_engine].get_and_reset_finished_requests_ids() @@ -1507,6 +1543,8 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: if allow_async_output_proc: execute_model_req.async_callback = self.async_callback[ virtual_engine] + execute_model_req.use_async_and_multi_step = \ + use_async_and_multi_step output = self.model_executor.execute_model( execute_model_req=execute_model_req) @@ -1518,7 +1556,7 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: else: # Nothing scheduled => If there is pending async postprocessor, # then finish it here. 
- if len(ctx.output_queue) > 0: + if not use_async_and_multi_step and len(ctx.output_queue) > 0: assert not self.scheduler_config.is_multi_step self._process_model_outputs(virtual_engine=virtual_engine, is_async=True) @@ -1535,18 +1573,23 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: if self.scheduler_config.is_multi_step: self.cached_scheduler_outputs[0] = SchedulerOutputState() - # Add results to the output_queue - # (for async or non-async postprocessing) - ctx.output_queue.append( - (output, seq_group_metadata_list, scheduler_outputs)) + if use_async_and_multi_step: + # For async + multi-step, clear the queue + ctx.output_queue.clear() + else: + # Add results to the output_queue + # (for async or non-async postprocessing) + ctx.output_queue.append( + (output, seq_group_metadata_list, scheduler_outputs)) - if output and allow_async_output_proc: - assert len(output) == 1, ("Multi step decoding does not work " - "with async output processing.") + if output and allow_async_output_proc: + assert len(output) == 1, ( + "Multi step decoding does not work " + "with async output processing.") - self._advance_to_next_step( - output[0], seq_group_metadata_list, - scheduler_outputs.scheduled_seq_groups) + self._advance_to_next_step( + output[0], seq_group_metadata_list, + scheduler_outputs.scheduled_seq_groups) # Check if need to run the usual non-async path if not allow_async_output_proc: @@ -1560,7 +1603,10 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: self.do_tracing(scheduler_outputs) else: # Multi-step case - ctx.request_outputs = [] + if use_async_and_multi_step: + return [] + else: + ctx.request_outputs = [] if not self.has_unfinished_requests(): # Drain async postprocessor (if exists) diff --git a/vllm/sequence.py b/vllm/sequence.py index 3125acc6fd535..a36de2da9616d 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -1262,6 +1262,12 @@ def expand_with_bonus_tokens( [self.hidden_states, self.second_last_token_hidden_states])[index] +@dataclass +class AsyncCallbackData: + func: Callable + kw_args: Dict[str, Any] + + class ExecuteModelRequest( msgspec.Struct, array_like=True, # type: ignore[call-arg] @@ -1295,6 +1301,7 @@ class ExecuteModelRequest( last_sampled_token_ids: Optional[torch.Tensor] = None # Async callback async_callback: Optional[Callable] = None + use_async_and_multi_step: bool = False @property def is_first_multi_step(self) -> bool: @@ -1341,4 +1348,5 @@ def clone( finished_requests_ids=self.finished_requests_ids, last_sampled_token_ids=self.last_sampled_token_ids.clone() if self.last_sampled_token_ids is not None else None, - async_callback=self.async_callback) + async_callback=self.async_callback, + use_async_and_multi_step=self.use_async_and_multi_step) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 2b287a5d27157..c4bdfce2d6ce4 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -6,8 +6,8 @@ import warnings import weakref from dataclasses import dataclass -from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, - Tuple, Type, TypeVar, Union) +from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple, Type, + TypeVar, Union) import numpy as np import torch @@ -41,8 +41,8 @@ from vllm.prompt_adapter.worker_manager import ( LRUCacheWorkerPromptAdapterManager) from vllm.sampling_params import SamplingParams -from vllm.sequence import (IntermediateTensors, SamplerOutput, - SequenceGroupMetadata) +from vllm.sequence import 
(AsyncCallbackData, IntermediateTensors, + SamplerOutput, SequenceGroupMetadata) from vllm.utils import (CudaMemoryProfiler, PyObjectCache, async_tensor_h2d, flatten_2d_lists, is_hip, is_pin_memory_available, supports_dynamo) @@ -92,6 +92,7 @@ class ModelInputForGPU(ModelRunnerInputBase): finished_requests_ids: Optional[List[str]] = None virtual_engine: int = 0 async_callback: Optional[Callable] = None + use_async_and_multi_step: bool = False def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: tensor_dict = { diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py index 521205eca05af..91bea68a9528e 100644 --- a/vllm/worker/multi_step_model_runner.py +++ b/vllm/worker/multi_step_model_runner.py @@ -1,3 +1,4 @@ +import dataclasses from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union @@ -13,9 +14,9 @@ from vllm import _custom_ops as ops from vllm.distributed import get_pp_group from vllm.logger import init_logger -from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors, - Logprob, SamplerOutput, SequenceGroupMetadata, - SequenceOutput) +from vllm.sequence import (AsyncCallbackData, CompletionSequenceGroupOutput, + IntermediateTensors, Logprob, SamplerOutput, + SequenceGroupMetadata, SequenceOutput) from vllm.worker.model_runner import (GPUModelRunnerBase, ModelInputForGPUWithSamplingMetadata) from vllm.worker.model_runner_base import ( @@ -215,6 +216,49 @@ def prepare_model_input( ) return model_input + def _async_process_outputs(self, model_input: StatefulModelInput, + output_proc_callback: AsyncCallbackData): + output_proc_fn = output_proc_callback.func + output_proc_kw_args = output_proc_callback.kw_args + virtual_engine = output_proc_kw_args["virtual_engine"] + + for model_output in model_input.cached_outputs: + if not model_output.pythonized: + model_output.maybe_pythonize(model_input, self._copy_stream, + self.pinned_sampled_token_ids) + if model_output.pythonized: + output_proc_fn(virtual_engine=virtual_engine, + is_async=False, + sampler_output=model_output.sampler_output) + + def _final_process_outputs(self, model_input: StatefulModelInput, + output_proc_callback: AsyncCallbackData): + assert model_input.frozen_model_input is not None + + if output_proc_callback is not None: + output_proc_fn = output_proc_callback.func + output_proc_kw_args = output_proc_callback.kw_args + virtual_engine = output_proc_kw_args["virtual_engine"] + + outputs = [] + for output_id in range(len(model_input.cached_outputs)): + is_last_output = output_id == len(model_input.cached_outputs) - 1 + + output = model_input.cached_outputs[output_id] + if not output.pythonized: + output.pythonize(model_input, self._copy_stream, + self.pinned_sampled_token_ids) + + if model_input.frozen_model_input.use_async_and_multi_step: + output_proc_fn(virtual_engine=virtual_engine, + is_async=False, + sampler_output=output.sampler_output, + is_last_output=is_last_output) + + outputs.append(output.sampler_output) + + return outputs + @torch.inference_mode() def execute_model( self, @@ -271,6 +315,20 @@ def execute_model( model_input = self._advance_step( model_input, model_input.cached_outputs[-1].sampler_output) + output_proc_callback = None + if frozen_model_input.use_async_and_multi_step: + output_proc_callback = frozen_model_input.async_callback + async_callback = AsyncCallbackData( + self._async_process_outputs, { + "model_input": model_input, + "output_proc_callback": output_proc_callback + }) + + 
frozen_model_input = dataclasses.replace( # type: ignore + model_input.frozen_model_input, + async_callback=async_callback) + assert frozen_model_input is not None + # Execute the model output = self._base_model_runner.execute_model(frozen_model_input, kv_caches, @@ -301,9 +359,11 @@ def execute_model( output[0].logprobs = None # Pythonize the output if CPU is ahead and the previous step is # ready. - for model_output in model_input.cached_outputs: - model_output.maybe_pythonize(model_input, self._copy_stream, - self.pinned_sampled_token_ids) + if not frozen_model_input.use_async_and_multi_step: + for model_output in model_input.cached_outputs: + model_output.maybe_pythonize(model_input, + self._copy_stream, + self.pinned_sampled_token_ids) model_input.current_step += 1 @@ -316,11 +376,8 @@ def execute_model( # Pythonize the output and block if needed since it is the last step if model_input.is_last_step: - outputs = [] - for output in model_input.cached_outputs: - output.pythonize(model_input, self._copy_stream, - self.pinned_sampled_token_ids) - outputs.append(output.sampler_output) + outputs = self._final_process_outputs(model_input, + output_proc_callback) return outputs # should be [SamplerOutput] diff --git a/vllm/worker/multi_step_worker.py b/vllm/worker/multi_step_worker.py index 2ed77dd698f5c..e0e421942f409 100644 --- a/vllm/worker/multi_step_worker.py +++ b/vllm/worker/multi_step_worker.py @@ -1,3 +1,4 @@ +import dataclasses from dataclasses import dataclass from typing import Dict, List, Optional, Tuple @@ -61,6 +62,13 @@ def _get_driver_input_and_broadcast( execute_model_req.seq_group_metadata_list, execute_model_req.virtual_engine, execute_model_req.finished_requests_ids)) + + if execute_model_req.async_callback: + model_input.frozen_model_input = dataclasses.replace( # type: ignore + model_input.frozen_model_input, + async_callback=execute_model_req.async_callback, + use_async_and_multi_step=execute_model_req. 
+ use_async_and_multi_step) else: # on subsequent steps we reuse the worker input and model input multi_step_state = self.multi_step_states[virtual_engine] From ea2c9899fd0118aa78f3f78910f8dea18ffe1958 Mon Sep 17 00:00:00 2001 From: Alexander Matveev Date: Wed, 28 Aug 2024 13:40:56 +0000 Subject: [PATCH 2/9] rebase fixes --- vllm/engine/async_llm_engine.py | 7 ++-- vllm/engine/llm_engine.py | 24 ++++++++++---- vllm/sequence.py | 6 ---- vllm/worker/model_runner.py | 8 ++--- vllm/worker/multi_step_model_runner.py | 44 ++++++++++---------------- 5 files changed, 43 insertions(+), 46 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 2d635be950164..3058214c50a5f 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -350,8 +350,11 @@ async def step_async( last_sampled_token_ids=last_sampled_token_ids) if allow_async_output_proc: - execute_model_req.async_callback = self.async_callback[ - virtual_engine] + async_callback = self.async_callback_multi_step[ + virtual_engine] if use_async_and_multi_step \ + else self.async_callback[virtual_engine] + + execute_model_req.async_callback = async_callback execute_model_req.use_async_and_multi_step = \ use_async_and_multi_step diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 1c4d747aa5aff..5059cc514c676 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -39,10 +39,9 @@ from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams -from vllm.sequence import (AsyncCallbackData, EmbeddingSequenceGroupOutput, - ExecuteModelRequest, SamplerOutput, Sequence, - SequenceGroup, SequenceGroupMetadata, - SequenceStatus) +from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest, + SamplerOutput, Sequence, SequenceGroup, + SequenceGroupMetadata, SequenceStatus) from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context, init_tracer) from vllm.transformers_utils.config import try_get_generation_config @@ -92,7 +91,8 @@ class SchedulerOutputState: @dataclass class SchedulerContext: - output_queue: Deque[Tuple[List[SamplerOutput], List[SequenceGroupMetadata], + output_queue: Deque[Tuple[Optional[List[SamplerOutput]], + List[SequenceGroupMetadata], SchedulerOutputs]] = field( default_factory=lambda: deque()) @@ -433,6 +433,13 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: for v_id in range(self.parallel_config.pipeline_parallel_size) ] + self.async_callback_multi_step = [ + functools.partial(self._process_model_outputs, + virtual_engine=v_id, + is_async=False) + for v_id in range(self.parallel_config.pipeline_parallel_size) + ] + def _initialize_kv_caches(self) -> None: """Initialize the KV cache in the worker(s). 
@@ -1541,8 +1548,11 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: last_sampled_token_ids=last_sampled_token_ids) if allow_async_output_proc: - execute_model_req.async_callback = self.async_callback[ - virtual_engine] + async_callback = self.async_callback_multi_step[ + virtual_engine] if use_async_and_multi_step \ + else self.async_callback[virtual_engine] + + execute_model_req.async_callback = async_callback execute_model_req.use_async_and_multi_step = \ use_async_and_multi_step diff --git a/vllm/sequence.py b/vllm/sequence.py index a36de2da9616d..e7cde87f605a7 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -1262,12 +1262,6 @@ def expand_with_bonus_tokens( [self.hidden_states, self.second_last_token_hidden_states])[index] -@dataclass -class AsyncCallbackData: - func: Callable - kw_args: Dict[str, Any] - - class ExecuteModelRequest( msgspec.Struct, array_like=True, # type: ignore[call-arg] diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index c4bdfce2d6ce4..c7c43a2dc56ab 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -6,8 +6,8 @@ import warnings import weakref from dataclasses import dataclass -from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple, Type, - TypeVar, Union) +from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, + Tuple, Type, TypeVar, Union) import numpy as np import torch @@ -41,8 +41,8 @@ from vllm.prompt_adapter.worker_manager import ( LRUCacheWorkerPromptAdapterManager) from vllm.sampling_params import SamplingParams -from vllm.sequence import (AsyncCallbackData, IntermediateTensors, - SamplerOutput, SequenceGroupMetadata) +from vllm.sequence import (IntermediateTensors, SamplerOutput, + SequenceGroupMetadata) from vllm.utils import (CudaMemoryProfiler, PyObjectCache, async_tensor_h2d, flatten_2d_lists, is_hip, is_pin_memory_available, supports_dynamo) diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py index 91bea68a9528e..1a10fea6bf50e 100644 --- a/vllm/worker/multi_step_model_runner.py +++ b/vllm/worker/multi_step_model_runner.py @@ -1,6 +1,7 @@ import dataclasses +import functools from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union try: from vllm.attention.backends.flash_attn import FlashAttentionMetadata @@ -14,9 +15,9 @@ from vllm import _custom_ops as ops from vllm.distributed import get_pp_group from vllm.logger import init_logger -from vllm.sequence import (AsyncCallbackData, CompletionSequenceGroupOutput, - IntermediateTensors, Logprob, SamplerOutput, - SequenceGroupMetadata, SequenceOutput) +from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors, + Logprob, SamplerOutput, SequenceGroupMetadata, + SequenceOutput) from vllm.worker.model_runner import (GPUModelRunnerBase, ModelInputForGPUWithSamplingMetadata) from vllm.worker.model_runner_base import ( @@ -217,29 +218,19 @@ def prepare_model_input( return model_input def _async_process_outputs(self, model_input: StatefulModelInput, - output_proc_callback: AsyncCallbackData): - output_proc_fn = output_proc_callback.func - output_proc_kw_args = output_proc_callback.kw_args - virtual_engine = output_proc_kw_args["virtual_engine"] - + output_proc_callback: Callable): for model_output in model_input.cached_outputs: if not model_output.pythonized: 
model_output.maybe_pythonize(model_input, self._copy_stream, self.pinned_sampled_token_ids) if model_output.pythonized: - output_proc_fn(virtual_engine=virtual_engine, - is_async=False, - sampler_output=model_output.sampler_output) + output_proc_callback( + sampler_output=model_output.sampler_output) def _final_process_outputs(self, model_input: StatefulModelInput, - output_proc_callback: AsyncCallbackData): + output_proc_callback: Optional[Callable]): assert model_input.frozen_model_input is not None - if output_proc_callback is not None: - output_proc_fn = output_proc_callback.func - output_proc_kw_args = output_proc_callback.kw_args - virtual_engine = output_proc_kw_args["virtual_engine"] - outputs = [] for output_id in range(len(model_input.cached_outputs)): is_last_output = output_id == len(model_input.cached_outputs) - 1 @@ -250,10 +241,9 @@ def _final_process_outputs(self, model_input: StatefulModelInput, self.pinned_sampled_token_ids) if model_input.frozen_model_input.use_async_and_multi_step: - output_proc_fn(virtual_engine=virtual_engine, - is_async=False, - sampler_output=output.sampler_output, - is_last_output=is_last_output) + assert output_proc_callback is not None + output_proc_callback(sampler_output=output.sampler_output, + is_last_output=is_last_output) outputs.append(output.sampler_output) @@ -318,11 +308,11 @@ def execute_model( output_proc_callback = None if frozen_model_input.use_async_and_multi_step: output_proc_callback = frozen_model_input.async_callback - async_callback = AsyncCallbackData( - self._async_process_outputs, { - "model_input": model_input, - "output_proc_callback": output_proc_callback - }) + assert output_proc_callback is not None + async_callback = functools.partial( + self._async_process_outputs, + model_input=model_input, + output_proc_callback=output_proc_callback) frozen_model_input = dataclasses.replace( # type: ignore model_input.frozen_model_input, From dcc4824666ba11dd5e9dd35ddf29fc28da2dc98d Mon Sep 17 00:00:00 2001 From: Alexander Matveev Date: Wed, 28 Aug 2024 15:29:09 +0000 Subject: [PATCH 3/9] ping --- vllm/engine/llm_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 5059cc514c676..74503923073ed 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1468,7 +1468,7 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: "as performance will be severely degraded otherwise.") # For llm_engine, there is no pipeline parallel support, so the engine - # used is always 0 + # used is always 0. virtual_engine = 0 # These are cached outputs from previous iterations. None if on first From b1c78f124bca64e50f9900e8f2e8e9dc48686dfa Mon Sep 17 00:00:00 2001 From: Alexander Matveev Date: Wed, 28 Aug 2024 17:17:09 +0000 Subject: [PATCH 4/9] ensure async order --- vllm/worker/multi_step_model_runner.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py index 1a10fea6bf50e..c671502841dc3 100644 --- a/vllm/worker/multi_step_model_runner.py +++ b/vllm/worker/multi_step_model_runner.py @@ -219,6 +219,9 @@ def prepare_model_input( def _async_process_outputs(self, model_input: StatefulModelInput, output_proc_callback: Callable): + # Proceed with pythonization and output_proc in order. 
+ # Stop on the first one that fails to pythonize + cont = True for model_output in model_input.cached_outputs: if not model_output.pythonized: model_output.maybe_pythonize(model_input, self._copy_stream, @@ -226,6 +229,11 @@ def _async_process_outputs(self, model_input: StatefulModelInput, if model_output.pythonized: output_proc_callback( sampler_output=model_output.sampler_output) + else: + cont = False + + if not cont: + break def _final_process_outputs(self, model_input: StatefulModelInput, output_proc_callback: Optional[Callable]): From 3e2b7a88953e468778eb6f9ec11f9503cb1b3dd3 Mon Sep 17 00:00:00 2001 From: Alexander Matveev Date: Wed, 28 Aug 2024 17:17:22 +0000 Subject: [PATCH 5/9] format --- vllm/worker/multi_step_model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py index c671502841dc3..0abca9d9f4558 100644 --- a/vllm/worker/multi_step_model_runner.py +++ b/vllm/worker/multi_step_model_runner.py @@ -219,7 +219,7 @@ def prepare_model_input( def _async_process_outputs(self, model_input: StatefulModelInput, output_proc_callback: Callable): - # Proceed with pythonization and output_proc in order. + # Proceed with pythonization and output_proc in order. # Stop on the first one that fails to pythonize cont = True for model_output in model_input.cached_outputs: From c29e4da9014ab0966eceef0ff32943c831754129 Mon Sep 17 00:00:00 2001 From: Alexander Matveev Date: Wed, 28 Aug 2024 20:29:49 +0000 Subject: [PATCH 6/9] Will's comments --- vllm/worker/multi_step_model_runner.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py index 0abca9d9f4558..7bc0df105c0dc 100644 --- a/vllm/worker/multi_step_model_runner.py +++ b/vllm/worker/multi_step_model_runner.py @@ -314,7 +314,8 @@ def execute_model( model_input, model_input.cached_outputs[-1].sampler_output) output_proc_callback = None - if frozen_model_input.use_async_and_multi_step: + if (frozen_model_input.use_async_and_multi_step + and model_input.is_first_multi_step): output_proc_callback = frozen_model_input.async_callback assert output_proc_callback is not None async_callback = functools.partial( @@ -326,6 +327,7 @@ def execute_model( model_input.frozen_model_input, async_callback=async_callback) assert frozen_model_input is not None + model_input.frozen_model_input = frozen_model_input # Execute the model output = self._base_model_runner.execute_model(frozen_model_input, From e26b18b35efc6b03f2b2c72221ccd71dbcf6a374 Mon Sep 17 00:00:00 2001 From: Alexander Matveev Date: Wed, 28 Aug 2024 20:40:38 +0000 Subject: [PATCH 7/9] format --- vllm/worker/multi_step_model_runner.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py index 7bc0df105c0dc..32df0d43b316b 100644 --- a/vllm/worker/multi_step_model_runner.py +++ b/vllm/worker/multi_step_model_runner.py @@ -313,7 +313,6 @@ def execute_model( model_input = self._advance_step( model_input, model_input.cached_outputs[-1].sampler_output) - output_proc_callback = None if (frozen_model_input.use_async_and_multi_step and model_input.is_first_multi_step): output_proc_callback = frozen_model_input.async_callback @@ -376,8 +375,10 @@ def execute_model( # Pythonize the output and block if needed since it is the last step if model_input.is_last_step: - outputs = self._final_process_outputs(model_input, - 
output_proc_callback) + outputs = self._final_process_outputs( + model_input, + model_input.frozen_model_input.async_callback. # type: ignore + keywords["output_proc_callback"]) # type: ignore return outputs # should be [SamplerOutput] From 942bc1229bec2b9d1ce5a80a5f894f9145a846c8 Mon Sep 17 00:00:00 2001 From: Alexander Matveev Date: Thu, 29 Aug 2024 02:30:17 +0000 Subject: [PATCH 8/9] review comments --- vllm/engine/llm_engine.py | 6 ++++++ vllm/worker/multi_step_model_runner.py | 10 ++++++---- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 74503923073ed..92c02072593e6 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1256,12 +1256,18 @@ def _process_model_outputs(self, """Apply the model output to the sequences in the scheduled seq groups. virtual_engine: The engine id to operate on + is_async: Indicates whether this postprocessor runs in parallel with the GPU forward pass and is processing tokens from the previous step. If this is true, then no tokens need to be appended since it is already done externally (before the next schedule() call) + sampler_output: Used with multi-step execution to provide + sampler_output of each step + is_last_output: Used with multi-step execution to indicate + the last step (of each multi-step group) + Returns RequestOutputs that can be returned to the client. """ now = time.time() diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py index 32df0d43b316b..1c10752006134 100644 --- a/vllm/worker/multi_step_model_runner.py +++ b/vllm/worker/multi_step_model_runner.py @@ -375,10 +375,12 @@ def execute_model( # Pythonize the output and block if needed since it is the last step if model_input.is_last_step: - outputs = self._final_process_outputs( - model_input, - model_input.frozen_model_input.async_callback. 
# type: ignore - keywords["output_proc_callback"]) # type: ignore + assert model_input.frozen_model_input is not None + async_callback = model_input.frozen_model_input.async_callback # type: ignore + output_proc_callback = async_callback.keywords[ + "output_proc_callback"] if async_callback is not None else None + outputs = self._final_process_outputs(model_input, + output_proc_callback) return outputs # should be [SamplerOutput] From dbbde98f5d896057fd908e4a7fa3f1f0d41234d6 Mon Sep 17 00:00:00 2001 From: Alexander Matveev Date: Thu, 29 Aug 2024 02:47:59 +0000 Subject: [PATCH 9/9] fix --- vllm/worker/multi_step_model_runner.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py index 1c10752006134..0abca9d9f4558 100644 --- a/vllm/worker/multi_step_model_runner.py +++ b/vllm/worker/multi_step_model_runner.py @@ -313,8 +313,8 @@ def execute_model( model_input = self._advance_step( model_input, model_input.cached_outputs[-1].sampler_output) - if (frozen_model_input.use_async_and_multi_step - and model_input.is_first_multi_step): + output_proc_callback = None + if frozen_model_input.use_async_and_multi_step: output_proc_callback = frozen_model_input.async_callback assert output_proc_callback is not None async_callback = functools.partial( @@ -326,7 +326,6 @@ def execute_model( model_input.frozen_model_input, async_callback=async_callback) assert frozen_model_input is not None - model_input.frozen_model_input = frozen_model_input # Execute the model output = self._base_model_runner.execute_model(frozen_model_input, @@ -375,10 +374,6 @@ def execute_model( # Pythonize the output and block if needed since it is the last step if model_input.is_last_step: - assert model_input.frozen_model_input is not None - async_callback = model_input.frozen_model_input.async_callback # type: ignore - output_proc_callback = async_callback.keywords[ - "output_proc_callback"] if async_callback is not None else None outputs = self._final_process_outputs(model_input, output_proc_callback) return outputs
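
The recurring pattern in this series — binding the engine's per-step output processor with functools.partial, attaching it to a frozen model-input dataclass via dataclasses.replace, and invoking it once per multi-step iteration with is_last_output marking the final step — can be sketched outside vLLM as below. This is a minimal illustration under assumed names only: ToyEngine, ToyModelInput, and run_steps are hypothetical stand-ins, not vLLM classes or APIs.

# Illustrative sketch (not part of the patch): per-step callback wiring with
# functools.partial + dataclasses.replace. All names here are hypothetical.
import dataclasses
import functools
from typing import Callable, List, Optional


@dataclasses.dataclass(frozen=True)
class ToyModelInput:
    # Frozen, so a callback is attached with dataclasses.replace(), not mutation.
    prompt_ids: List[int]
    async_callback: Optional[Callable] = None


class ToyEngine:
    def __init__(self) -> None:
        self.processed: List[str] = []

    def process_outputs(self, virtual_engine: int, sampler_output: str,
                        is_last_output: bool = False) -> None:
        # Stand-in for an engine-side output processor: handle one step's
        # output; finalization work runs only when is_last_output is True.
        self.processed.append(f"ve{virtual_engine}:{sampler_output}")
        if is_last_output:
            self.processed.append("<finalized>")


def run_steps(model_input: ToyModelInput, num_steps: int) -> None:
    # Stand-in for a multi-step runner loop: call the bound callback after
    # each step, flagging the last step so the caller can finalize/free.
    assert model_input.async_callback is not None
    for step in range(num_steps):
        output = f"step-{step}-tokens"
        model_input.async_callback(sampler_output=output,
                                   is_last_output=(step == num_steps - 1))


if __name__ == "__main__":
    engine = ToyEngine()
    # Bind engine-side keyword arguments once, analogous to the
    # functools.partial(self._process_model_outputs, virtual_engine=..., ...)
    # wiring added in the patch.
    callback = functools.partial(engine.process_outputs, virtual_engine=0)
    model_input = ToyModelInput(prompt_ids=[1, 2, 3])
    # Attach the callback to the frozen dataclass without mutating it.
    model_input = dataclasses.replace(model_input, async_callback=callback)
    run_steps(model_input, num_steps=4)
    print(engine.processed)

Running the sketch prints one processed entry per step followed by "<finalized>", mirroring how the patch defers finished-sequence freeing and request-output creation to the last step of each multi-step group.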