[Core] Add multi-step support to LLMEngine (vllm-project#7789)
alexm-neuralmagic authored and omrishiv committed Aug 26, 2024
1 parent 3fcb1b6 commit b63d237
Showing 7 changed files with 195 additions and 87 deletions.
3 changes: 2 additions & 1 deletion .buildkite/test-pipeline.yaml
@@ -335,7 +335,8 @@ steps:
- vllm/engine
- tests/multi_step
commands:
- pytest -v -s multi_step/test_correctness.py
- pytest -v -s multi_step/test_correctness_async_llm.py
- pytest -v -s multi_step/test_correctness_llm.py

- label: Pipeline Parallelism Test # 23min
working_dir: "/vllm-workspace/tests"
17 changes: 15 additions & 2 deletions benchmarks/benchmark_throughput.py
@@ -82,6 +82,8 @@ def run_vllm(
max_num_batched_tokens: int,
distributed_executor_backend: Optional[str],
gpu_memory_utilization: float = 0.9,
num_scheduler_steps: int = 1,
use_v2_block_manager: bool = False,
download_dir: Optional[str] = None,
load_format: str = EngineArgs.load_format,
) -> float:
@@ -106,6 +108,8 @@ def run_vllm(
max_num_batched_tokens=max_num_batched_tokens,
distributed_executor_backend=distributed_executor_backend,
load_format=load_format,
num_scheduler_steps=num_scheduler_steps,
use_v2_block_manager=use_v2_block_manager,
)

# Add the requests to the engine.
@@ -232,7 +236,8 @@ def main(args: argparse.Namespace):
args.quantization_param_path, args.device,
args.enable_prefix_caching, args.enable_chunked_prefill,
args.max_num_batched_tokens, args.distributed_executor_backend,
args.gpu_memory_utilization, args.download_dir, args.load_format)
args.gpu_memory_utilization, args.num_scheduler_steps,
args.use_v2_block_manager, args.download_dir, args.load_format)
elif args.backend == "hf":
assert args.tensor_parallel_size == 1
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -353,10 +358,18 @@ def main(args: argparse.Namespace):
choices=["auto", "cuda", "cpu", "openvino", "tpu", "xpu"],
help='device type for vLLM execution, supporting CUDA, OpenVINO and '
'CPU.')
parser.add_argument(
"--num-scheduler-steps",
type=int,
default=1,
help="Maximum number of forward steps per scheduler call.")
parser.add_argument("--use-v2-block-manager",
action='store_true',
help="Enable block manager v2.")
parser.add_argument(
"--enable-prefix-caching",
action='store_true',
help="enable automatic prefix caching for vLLM backend.")
help="Enable automatic prefix caching for vLLM backend.")
parser.add_argument("--enable-chunked-prefill",
action='store_true',
help="enable chunked prefill for vLLM backend.")
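
For context, a minimal sketch (not part of the commit) of the path the two new flags take: run_vllm() above forwards num_scheduler_steps and use_v2_block_manager straight into the engine arguments, so the equivalent offline-inference call looks roughly like the code below. The model name, prompt, and sampling settings are illustrative assumptions; num_scheduler_steps=8 matches the value exercised by the new test further down.

# Illustrative sketch: enable multi-step scheduling on the synchronous engine,
# mirroring the keyword arguments run_vllm() now passes through.
from vllm import LLM, SamplingParams

llm = LLM(
    model="JackFram/llama-160m",  # small model; the new multi-step test uses the same one
    num_scheduler_steps=8,        # run up to 8 forward steps per scheduler call
    use_v2_block_manager=True,    # the new test enables block manager v2 alongside multi-step
)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)

Running benchmarks/benchmark_throughput.py with --num-scheduler-steps 8 --use-v2-block-manager exercises the same path through run_vllm().
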
2 changes: 1 addition & 1 deletion tests/lora/test_gemma.py
@@ -37,7 +37,7 @@ def test_gemma_lora(gemma_lora_files):
expected_lora_output = [
"more important than knowledge.\nAuthor: Albert Einstein\n",
"everyone else is already taken.\nAuthor: Oscar Wilde\n",
"so little time.\nAuthor: Frank Zappa\n",
"so little time\nAuthor: Frank Zappa\n",
]

output1 = do_sample(llm, gemma_lora_files, lora_id=1)
File renamed without changes: tests/multi_step/test_correctness.py → tests/multi_step/test_correctness_async_llm.py
49 changes: 49 additions & 0 deletions tests/multi_step/test_correctness_llm.py
@@ -0,0 +1,49 @@
# Test the LLMEngine with multi-step-decoding

import pytest

from ..models.utils import check_outputs_equal

MODELS = [
"JackFram/llama-160m",
]
NUM_SCHEDULER_STEPS = [8] # Multi-step decoding steps
NUM_PROMPTS = [10]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("tp_size", [1])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("enforce_eager", [True])
@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
@pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
def test_multi_step_llm(hf_runner, vllm_runner, example_prompts, model: str,
dtype: str, tp_size: int, max_tokens: int,
enforce_eager: bool, num_scheduler_steps: int,
num_prompts: int) -> None:

prompts = example_prompts
if len(prompts) < num_prompts:
prompts = prompts * ((num_prompts // len(prompts)) + 1)
prompts = prompts[:num_prompts]
assert len(prompts) == num_prompts

with vllm_runner(model,
dtype=dtype,
enforce_eager=enforce_eager,
gpu_memory_utilization=0.7,
tensor_parallel_size=tp_size,
use_v2_block_manager=True,
num_scheduler_steps=num_scheduler_steps) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(prompts, max_tokens)

with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy(prompts, max_tokens)

check_outputs_equal(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
)
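
check_outputs_equal is imported from tests/models/utils.py and is not part of this diff. As an assumption about what the test relies on, it compares the (token_ids, text) pair produced for each prompt and reports the runner names on a mismatch; a rough, self-contained sketch:

# Assumed sketch only; the real helper lives in tests/models/utils.py and is
# not shown in this diff.
from typing import List, Sequence, Tuple

TokensText = Tuple[List[int], str]


def check_outputs_equal_sketch(*, outputs_0_lst: Sequence[TokensText],
                               outputs_1_lst: Sequence[TokensText],
                               name_0: str, name_1: str) -> None:
    assert len(outputs_0_lst) == len(outputs_1_lst)
    for i, ((ids_0, text_0), (ids_1, text_1)) in enumerate(
            zip(outputs_0_lst, outputs_1_lst)):
        assert ids_0 == ids_1 and text_0 == text_1, (
            f"prompt {i}: {name_0} and {name_1} disagree\n"
            f"{name_0}: {text_0!r}\n{name_1}: {text_1!r}")

With greedy decoding and max_tokens=5, the multi-step engine is required to reproduce the HF outputs token for token.
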
74 changes: 2 additions & 72 deletions vllm/engine/async_llm_engine.py
@@ -1,11 +1,9 @@
import asyncio
import time
from dataclasses import dataclass
from functools import partial
from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List,
Mapping, Optional, Set, Tuple, Type, Union)

import torch
from typing_extensions import assert_never

import vllm.envs as envs
@@ -15,7 +13,7 @@
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_timeout import asyncio_timeout
from vllm.engine.llm_engine import (DecoderPromptComponents, LLMEngine,
PromptComponents)
PromptComponents, SchedulerOutputState)
from vllm.engine.metrics_types import StatLoggerBase
from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.ray_utils import initialize_ray_cluster, ray
@@ -28,8 +26,7 @@
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
SequenceGroupMetadata)
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.usage.usage_lib import UsageContext
from vllm.utils import print_warning_once
@@ -257,24 +254,11 @@ def has_new_requests(self):
return not self._new_requests.empty()


@dataclass
class SchedulerOutputState:
"""Caches the scheduler outputs for a virtual engine. Used for Multi-Step"""
last_output: Optional[SamplerOutput] = None
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
scheduler_outputs: Optional[SchedulerOutputs] = None


class _AsyncLLMEngine(LLMEngine):
"""Extension of LLMEngine to add async methods."""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
pipeline_parallel_size = \
self.parallel_config.pipeline_parallel_size
self.cached_scheduler_outputs = [
SchedulerOutputState() for _ in range(pipeline_parallel_size)
]

async def step_async(
self, virtual_engine: int
@@ -367,60 +351,6 @@ async def step_async(

return request_outputs

def _has_remaining_steps(
self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
) -> bool:
if (not self.scheduler_config.is_multi_step
or not seq_group_metadata_list):
return False

# TODO(will) this is a sanity check for now to make sure that all the
# seqs are on the same steps. Eventually we will want to do some sort of
# dynamic scheduling when doing multi-step decoding.
ref_remaining_steps = seq_group_metadata_list[0].state.remaining_steps
if any([
seq_group.state.remaining_steps != ref_remaining_steps
for seq_group in seq_group_metadata_list[1:]
]):
raise AssertionError(("All running sequence groups should "
"have the same remaining steps."))

return ref_remaining_steps > 0

def _cache_scheduler_outputs_for_multi_step(
self, virtual_engine: int,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
scheduler_outputs: SchedulerOutputs) -> None:
self.cached_scheduler_outputs[
virtual_engine].seq_group_metadata_list = seq_group_metadata_list
self.cached_scheduler_outputs[virtual_engine].scheduler_outputs = \
scheduler_outputs
self.cached_scheduler_outputs[virtual_engine].last_output = None

def _get_last_sampled_token_ids(
self, virtual_engine: int) -> Optional[torch.Tensor]:
cached_last_output = self.cached_scheduler_outputs[
virtual_engine].last_output
if (self.scheduler_config.is_multi_step
and self.parallel_config.pipeline_parallel_size > 1
and cached_last_output is not None
and cached_last_output.sampled_token_ids_cpu is not None):
return cached_last_output.sampled_token_ids_cpu
return None

def _update_cached_scheduler_output(
self, virtual_engine: int,
output: List[Optional[SamplerOutput]]) -> None:
if (self.parallel_config.pipeline_parallel_size > 1 and len(output) > 0
and output[0] is not None):
last_output = output[-1]
assert last_output is not None
assert last_output.sampled_token_ids_cpu is not None
assert last_output.sampled_token_ids is None
assert last_output.sampled_token_probs is None
self.cached_scheduler_outputs[
virtual_engine].last_output = last_output

async def stop_remote_worker_execution_loop_async(self) -> None:
"""Stop the remote worker execution loop."""
await self.model_executor.stop_remote_worker_execution_loop_async()
137 changes: 126 additions & 11 deletions vllm/engine/llm_engine.py (diff not rendered)
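
Because the llm_engine.py portion of the diff is not rendered here, the following is a reconstruction of the state that moves into LLMEngine, based only on the lines removed from async_llm_engine.py above: the SchedulerOutputState dataclass and the per-virtual-engine cache that the old _AsyncLLMEngine.__init__ built. The factory function name is illustrative, not taken from the commit.

# Reconstructed from the removed code above; SchedulerOutputState is now
# imported back into async_llm_engine.py from vllm.engine.llm_engine.
from dataclasses import dataclass
from typing import List, Optional

from vllm.core.scheduler import SchedulerOutputs
from vllm.sequence import SamplerOutput, SequenceGroupMetadata


@dataclass
class SchedulerOutputState:
    """Caches the scheduler outputs for a virtual engine. Used for Multi-Step"""
    last_output: Optional[SamplerOutput] = None
    seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
    scheduler_outputs: Optional[SchedulerOutputs] = None


def make_multi_step_cache(pipeline_parallel_size: int) -> List[SchedulerOutputState]:
    # One cache slot per virtual engine (one per pipeline-parallel rank), as the
    # removed _AsyncLLMEngine.__init__ used to build.
    return [SchedulerOutputState() for _ in range(pipeline_parallel_size)]

The helpers removed above (_has_remaining_steps, _cache_scheduler_outputs_for_multi_step, _get_last_sampled_token_ids and _update_cached_scheduler_output) move along with this state, so the cached schedule can be reused across num_scheduler_steps forward passes by the synchronous engine as well as by step_async().
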
