diff --git a/.buildkite/run-benchmarks.sh b/.buildkite/run-benchmarks.sh new file mode 100644 index 000000000000..c4e6b21d074f --- /dev/null +++ b/.buildkite/run-benchmarks.sh @@ -0,0 +1,35 @@ +# This script is run by buildkite to run the benchmarks and upload the results to buildkite + +set -ex + +# cd into parent directory of this file +cd "$(dirname "${BASH_SOURCE[0]}")/.." + +# run benchmarks and upload the result to buildkite +python3 benchmarks/benchmark_latency.py 2>&1 | tee benchmark_latency.txt +bench_latency_exit_code=$? + +python3 benchmarks/benchmark_throughput.py --input-len 256 --output-len 256 2>&1 | tee benchmark_throughput.txt +bench_throughput_exit_code=$? + +# write the results into a markdown file +echo "### Latency Benchmarks" >> benchmark_results.md +sed -n '1p' benchmark_latency.txt >> benchmark_results.md +echo "" >> benchmark_results.md +sed -n '$p' benchmark_latency.txt >> benchmark_results.md +echo "### Throughput Benchmarks" >> benchmark_results.md +sed -n '1p' benchmark_throughput.txt >> benchmark_results.md +echo "" >> benchmark_results.md +sed -n '$p' benchmark_throughput.txt >> benchmark_results.md + +# upload the results to buildkite +/workspace/buildkite-agent annotate --style "info" --context "benchmark-results" < benchmark_results.md + +# exit with the exit code of the benchmarks +if [ $bench_latency_exit_code -ne 0 ]; then + exit $bench_latency_exit_code +fi + +if [ $bench_throughput_exit_code -ne 0 ]; then + exit $bench_throughput_exit_code +fi diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml new file mode 100644 index 000000000000..a6f3a3f0a2e3 --- /dev/null +++ b/.buildkite/test-pipeline.yaml @@ -0,0 +1,44 @@ +# In this file, you can add more tests to run either by adding a new step or +# adding a new command to an existing step. See different options here for examples. +# This script will be feed into Jinja template in `test-template.j2` to generate +# the final pipeline yaml file. 
+ +steps: +- label: Regression Test + command: pytest -v -s test_regression.py + working_dir: "/vllm-workspace/tests" # optional + +- label: AsyncEngine Test + command: pytest -v -s async_engine + +- label: Distributed Test + command: pytest -v -s test_comm_ops.py + working_dir: "/vllm-workspace/tests/distributed" + num_gpus: 2 # only support 1 or 2 for now. + +- label: Engine Test + command: pytest -v -s engine + +- label: Entrypoints Test + command: pytest -v -s entrypoints + +- label: Kernels Test + command: pytest -v -s kernels + soft_fail: true + +- label: Models Test + commands: + - pytest -v -s models --forked + soft_fail: true + +- label: Samplers Test + command: pytest -v -s samplers --forked + +- label: Worker Test + command: pytest -v -s worker + +- label: Benchmarks + working_dir: "/vllm-workspace/.buildkite" + commands: + - pip install aiohttp + - bash run-benchmarks.sh diff --git a/.buildkite/test-template.j2 b/.buildkite/test-template.j2 new file mode 100644 index 000000000000..b35511293539 --- /dev/null +++ b/.buildkite/test-template.j2 @@ -0,0 +1,54 @@ +{% set docker_image = "us-central1-docker.pkg.dev/vllm-405802/vllm-ci-test-repo/vllm-test:$BUILDKITE_COMMIT" %} +{% set default_num_gpu = 1 %} +{% set default_working_dir = "/vllm-workspace/tests" %} + +steps: + - label: ":docker: build image" + commands: + - "docker build --tag {{ docker_image }} --target test --progress plain ." 
+ - "docker push {{ docker_image }}" + env: + DOCKER_BUILDKIT: "1" + retry: + automatic: + - exit_status: -1 # Agent was lost + limit: 5 + - wait + + {% for step in steps %} + - label: "{{ step.label }}" + agents: + queue: kubernetes + soft_fail: {{ step.soft_fail or false }} + retry: + automatic: + - exit_status: -1 # Agent was lost + limit: 5 + plugins: + - kubernetes: + podSpec: + volumes: + - name: dshm + emptyDir: + medium: Memory + containers: + - image: "{{ docker_image }}" + command: ["bash"] + args: + - "-c" + - "'cd {{ (step.working_dir or default_working_dir) | safe }} && {{ step.command or (step.commands | join(' && ')) | safe }}'" + resources: + requests: + nvidia.com/gpu: "{{ step.num_gpus or default_num_gpu }}" + limits: + nvidia.com/gpu: "{{ step.num_gpus or default_num_gpu }}" + env: + - name: HF_TOKEN + valueFrom: + secretKeyRef: + name: hf-token-secret + key: token + volumeMounts: + - mountPath: /dev/shm + name: dshm + {% endfor %} diff --git a/Dockerfile b/Dockerfile index bd66afe79c7e..44b1dd17d7e0 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,7 +1,11 @@ +# The vLLM Dockerfile is used to construct vLLM image that can be directly used +# to run the OpenAI compatible server. 
+ +#################### BASE BUILD IMAGE #################### FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 AS dev RUN apt-get update -y \ - && apt-get install -y python3-pip + && apt-get install -y python3-pip git WORKDIR /workspace @@ -14,8 +18,10 @@ RUN --mount=type=cache,target=/root/.cache/pip \ COPY requirements-dev.txt requirements-dev.txt RUN --mount=type=cache,target=/root/.cache/pip \ pip install -r requirements-dev.txt +#################### BASE BUILD IMAGE #################### + -# image to build pytorch extensions +#################### EXTENSION BUILD IMAGE #################### FROM dev AS build # install build dependencies @@ -30,6 +36,7 @@ COPY requirements.txt requirements.txt COPY pyproject.toml pyproject.toml COPY vllm/__init__.py vllm/__init__.py +# cuda arch list used by torch ARG torch_cuda_arch_list='7.0 7.5 8.0 8.6 8.9 9.0+PTX' ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list} # max jobs used by Ninja to build extensions @@ -40,18 +47,26 @@ ARG nvcc_threads=8 ENV NVCC_THREADS=$nvcc_threads RUN python3 setup.py build_ext --inplace +#################### EXTENSION Build IMAGE #################### + +#################### TEST IMAGE #################### # image to run unit testing suite FROM dev AS test # copy pytorch extensions separately to avoid having to rebuild # when python code changes -COPY --from=build /workspace/vllm/*.so /workspace/vllm/ -COPY tests tests -COPY vllm vllm +WORKDIR /vllm-workspace +# ADD is used to preserve directory structure +ADD . /vllm-workspace/ +COPY --from=build /workspace/vllm/*.so /vllm-workspace/vllm/ +# ignore build dependencies installation because we are using pre-complied extensions +RUN rm pyproject.toml +RUN --mount=type=cache,target=/root/.cache/pip VLLM_USE_PRECOMPILED=1 pip install . 
--verbose +#################### TEST IMAGE #################### -ENTRYPOINT ["python3", "-m", "pytest", "tests"] +#################### RUNTIME BASE IMAGE #################### # use CUDA base as CUDA runtime dependencies are already installed via pip FROM nvidia/cuda:12.1.0-base-ubuntu22.04 AS vllm-base @@ -63,14 +78,10 @@ WORKDIR /workspace COPY requirements.txt requirements.txt RUN --mount=type=cache,target=/root/.cache/pip \ pip install -r requirements.txt +#################### RUNTIME BASE IMAGE #################### -FROM vllm-base AS vllm -COPY --from=build /workspace/vllm/*.so /workspace/vllm/ -COPY vllm vllm - -EXPOSE 8000 -ENTRYPOINT ["python3", "-m", "vllm.entrypoints.api_server"] +#################### OPENAI API SERVER #################### # openai api server alternative FROM vllm-base AS vllm-openai # install additional dependencies for openai api server @@ -81,3 +92,4 @@ COPY --from=build /workspace/vllm/*.so /workspace/vllm/ COPY vllm vllm ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"] +#################### OPENAI API SERVER #################### diff --git a/README.md b/README.md index 8ea4d029dc64..d30f56712996 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,15 @@ Easy, fast, and cheap LLM serving for everyone --- +**The Second vLLM Bay Area Meetup (Jan 31st 5pm-7:30pm PT)** + +We are thrilled to announce our second vLLM Meetup! +The vLLM team will share recent updates and roadmap. +We will also have vLLM collaborators from IBM coming up to the stage to discuss their insights on LLM optimizations. +Please register [here](https://lu.ma/ygxbpzhl) and join us! + +--- + *Latest News* 🔥 - [2023/12] Added ROCm support to vLLM. - [2023/10] We hosted [the first vLLM meetup](https://lu.ma/first-vllm-meetup) in SF! Please find the meetup slides [here](https://docs.google.com/presentation/d/1QL-XPFXiFpDBh86DbEegFXBXFXjix4v032GhShbKf3s/edit?usp=sharing). 
@@ -68,6 +77,7 @@ vLLM seamlessly supports many Hugging Face models, including the following archi - OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.) - Phi (`microsoft/phi-1_5`, `microsoft/phi-2`, etc.) - Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.) +- StableLM(`stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc.) - Yi (`01-ai/Yi-6B`, `01-ai/Yi-34B`, etc.) Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source): diff --git a/docs/source/conf.py b/docs/source/conf.py index d0c64cf53230..44c976468ab0 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -9,11 +9,15 @@ # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. -# -# import os -# import sys -# sys.path.insert(0, os.path.abspath('.')) +import os +import sys +from sphinx.ext import autodoc +import logging + +sys.path.insert(0, os.path.abspath(os.path.join('..', '..'))) + +logger = logging.getLogger(__name__) # -- Project information ----------------------------------------------------- @@ -21,7 +25,6 @@ copyright = '2023, vLLM Team' author = 'the vLLM Team' - # -- General configuration --------------------------------------------------- # Add any Sphinx extension module names here, as strings. They can be @@ -32,6 +35,8 @@ "sphinx.ext.viewcode", "sphinx.ext.intersphinx", "sphinx_copybutton", + "sphinx.ext.autodoc", + "sphinx.ext.autosummary", ] # Add any paths that contain templates here, relative to this directory. 
@@ -55,7 +60,6 @@ html_theme = 'sphinx_book_theme' html_logo = 'assets/logos/vllm-logo-text-light.png' html_theme_options = { - 'logo_only': True, 'path_to_docs': 'docs/source', 'repository_url': 'https://github.com/vllm-project/vllm', 'use_repository_button': True, @@ -64,4 +68,29 @@ # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +# html_static_path = ['_static'] + +# Mock out external dependencies here. +autodoc_mock_imports = [ + "torch", "transformers", "psutil", "aioprometheus", "sentencepiece", + "vllm.cuda_utils", "vllm._C" +] + +for mock_target in autodoc_mock_imports: + if mock_target in sys.modules: + logger.info( + f"Potentially problematic mock target ({mock_target}) found; " + "autodoc_mock_imports cannot mock modules that have already " + "been loaded into sys.modules when the sphinx build starts.") + + +class MockedClassDocumenter(autodoc.ClassDocumenter): + """Remove note about base class when a class is derived from object.""" + + def add_line(self, line: str, source: str, *lineno: int) -> None: + if line == " Bases: :py:class:`object`": + return + super().add_line(line, source, *lineno) + + +autodoc.ClassDocumenter = MockedClassDocumenter diff --git a/docs/source/dev/engine/async_llm_engine.rst b/docs/source/dev/engine/async_llm_engine.rst new file mode 100644 index 000000000000..47db1e0a401b --- /dev/null +++ b/docs/source/dev/engine/async_llm_engine.rst @@ -0,0 +1,7 @@ + +AsyncLLMEngine +================================= + +.. 
autoclass:: vllm.engine.async_llm_engine.AsyncLLMEngine + :members: generate, abort + :show-inheritance: diff --git a/docs/source/dev/engine/engine_index.rst b/docs/source/dev/engine/engine_index.rst new file mode 100644 index 000000000000..ba9ae55ddea4 --- /dev/null +++ b/docs/source/dev/engine/engine_index.rst @@ -0,0 +1,13 @@ +vLLM Engine +================================= + +.. automodule:: vllm.engine +.. currentmodule:: vllm.engine + +.. toctree:: + :maxdepth: 2 + :caption: Engines + + llm_engine + async_llm_engine + diff --git a/docs/source/dev/engine/llm_engine.rst b/docs/source/dev/engine/llm_engine.rst new file mode 100644 index 000000000000..b550a9b5faa6 --- /dev/null +++ b/docs/source/dev/engine/llm_engine.rst @@ -0,0 +1,6 @@ +LLMEngine +================================= + +.. autoclass:: vllm.engine.llm_engine.LLMEngine + :members: add_request, abort_request, step, _init_cache + :show-inheritance: \ No newline at end of file diff --git a/docs/source/getting_started/quickstart.rst b/docs/source/getting_started/quickstart.rst index 1a423b64f7e4..5ce3c096cb44 100644 --- a/docs/source/getting_started/quickstart.rst +++ b/docs/source/getting_started/quickstart.rst @@ -11,6 +11,14 @@ This guide shows how to use vLLM to: Be sure to complete the :ref:`installation instructions ` before continuing with this guide. +.. note:: + + By default, vLLM downloads model from `HuggingFace `_. If you would like to use models from `ModelScope `_ in the following examples, please set the environment variable: + + .. code-block:: shell + + export VLLM_USE_MODELSCOPE=True + Offline Batched Inference ------------------------- @@ -40,16 +48,6 @@ Initialize vLLM's engine for offline inference with the ``LLM`` class and the `O llm = LLM(model="facebook/opt-125m") -Use model from www.modelscope.cn - -.. code-block:: shell - - export VLLM_USE_MODELSCOPE=True - -.. 
code-block:: python - - llm = LLM(model="qwen/Qwen-7B-Chat", revision="v1.1.8", trust_remote_code=True) - Call ``llm.generate`` to generate the outputs. It adds the input prompts to vLLM engine's waiting queue and executes the vLLM engine to generate the outputs with high throughput. The outputs are returned as a list of ``RequestOutput`` objects, which include all the output tokens. .. code-block:: python @@ -77,16 +75,6 @@ Start the server: $ python -m vllm.entrypoints.api_server -Use model from www.modelscope.cn - -.. code-block:: console - - $ VLLM_USE_MODELSCOPE=True python -m vllm.entrypoints.api_server \ - $ --model="qwen/Qwen-7B-Chat" \ - $ --revision="v1.1.8" \ - $ --trust-remote-code - - By default, this command starts the server at ``http://localhost:8000`` with the OPT-125M model. Query the model in shell: @@ -107,7 +95,7 @@ OpenAI-Compatible Server ------------------------ vLLM can be deployed as a server that mimics the OpenAI API protocol. This allows vLLM to be used as a drop-in replacement for applications using OpenAI API. -By default, it starts the server at ``http://localhost:8000``. You can specify the address with ``--host`` and ``--port`` arguments. The server currently hosts one model at a time (OPT-125M in the above command) and implements `list models `_, `create chat completion `_, and `create completion `_ endpoints. We are actively adding support for more endpoints. +By default, it starts the server at ``http://localhost:8000``. You can specify the address with ``--host`` and ``--port`` arguments. The server currently hosts one model at a time (OPT-125M in the command below) and implements `list models `_, `create chat completion `_, and `create completion `_ endpoints. We are actively adding support for more endpoints. Start the server: @@ -116,13 +104,6 @@ Start the server: $ python -m vllm.entrypoints.openai.api_server \ $ --model facebook/opt-125m -Use model from www.modelscope.cn - -.. 
code-block:: console - - $ VLLM_USE_MODELSCOPE=True python -m vllm.entrypoints.openai.api_server \ - $ --model="qwen/Qwen-7B-Chat" --revision="v1.1.8" --trust-remote-code - By default, the server uses a predefined chat template stored in the tokenizer. You can override this template by using the ``--chat-template`` argument: .. code-block:: console diff --git a/docs/source/index.rst b/docs/source/index.rst index 816f4f7e2015..321f855645bb 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -85,4 +85,16 @@ Documentation :maxdepth: 1 :caption: Quantization - quantization/auto_awq \ No newline at end of file + quantization/auto_awq + +.. toctree:: + :maxdepth: 2 + :caption: Developer Documentation + + dev/engine/engine_index + +Indices and tables +================== + +* :ref:`genindex` +* :ref:`modindex` diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 361ad5f5a22b..1c5ab9f6592c 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -68,6 +68,9 @@ Alongside each architecture, we include some popular models that use it. * - :code:`QWenLMHeadModel` - Qwen - :code:`Qwen/Qwen-7B`, :code:`Qwen/Qwen-7B-Chat`, etc. + * - :code:`StableLMEpochForCausalLM` + - StableLM + - :code:`stabilityai/stablelm-3b-4e1t/` , :code:`stabilityai/stablelm-base-alpha-7b-v2`, etc. * - :code:`YiForCausalLM` - Yi - :code:`01-ai/Yi-6B`, :code:`01-ai/Yi-34B`, etc. 
diff --git a/examples/gradio_openai_chatbot_webserver.py b/examples/gradio_openai_chatbot_webserver.py new file mode 100644 index 000000000000..61e91d6b0c8b --- /dev/null +++ b/examples/gradio_openai_chatbot_webserver.py @@ -0,0 +1,81 @@ +import argparse +from openai import OpenAI +import gradio as gr + +# Argument parser setup +parser = argparse.ArgumentParser( + description='Chatbot Interface with Customizable Parameters') +parser.add_argument('--model-url', + type=str, + default='http://localhost:8000/v1', + help='Model URL') +parser.add_argument('-m', + '--model', + type=str, + required=True, + help='Model name for the chatbot') +parser.add_argument('--temp', + type=float, + default=0.8, + help='Temperature for text generation') +parser.add_argument('--stop-token-ids', + type=str, + default='', + help='Comma-separated stop token IDs') +parser.add_argument("--host", type=str, default=None) +parser.add_argument("--port", type=int, default=8001) + +# Parse the arguments +args = parser.parse_args() + +# Set OpenAI's API key and API base to use vLLM's API server. +openai_api_key = "EMPTY" +openai_api_base = args.model_url + +# Create an OpenAI client to interact with the API server +client = OpenAI( + api_key=openai_api_key, + base_url=openai_api_base, +) + + +def predict(message, history): + # Convert chat history to OpenAI format + history_openai_format = [{ + "role": "system", + "content": "You are a great ai assistant." 
+ }] + for human, assistant in history: + history_openai_format.append({"role": "user", "content": human}) + history_openai_format.append({ + "role": "assistant", + "content": assistant + }) + history_openai_format.append({"role": "user", "content": message}) + + # Create a chat completion request and send it to the API server + stream = client.chat.completions.create( + model=args.model, # Model name to use + messages=history_openai_format, # Chat history + temperature=args.temp, # Temperature for text generation + stream=True, # Stream response + extra_body={ + 'repetition_penalty': + 1, + 'stop_token_ids': [ + int(id.strip()) for id in args.stop_token_ids.split(',') + if id.strip() + ] if args.stop_token_ids else [] + }) + + # Read and return generated text from response stream + partial_message = "" + for chunk in stream: + partial_message += (chunk.choices[0].delta.content or "") + yield partial_message + + +# Create and launch a chat interface with Gradio +gr.ChatInterface(predict).queue().launch(server_name=args.host, + server_port=args.port, + share=True) diff --git a/examples/template_baichuan.jinja b/examples/template_baichuan.jinja new file mode 100644 index 000000000000..a1812a6c09ab --- /dev/null +++ b/examples/template_baichuan.jinja @@ -0,0 +1,22 @@ +{{ (messages|selectattr('role', 'equalto', 'system')|list|last).content|trim if (messages|selectattr('role', 'equalto', 'system')|list) else '' }} + +{% for message in messages %} +{% if message['role'] == 'user' %} + +{{ message['content']|trim -}} +{% if not loop.last %} + + +{% endif %} +{% elif message['role'] == 'assistant' %} + +{{ message['content']|trim -}} +{% if not loop.last %} + + +{% endif %} +{% endif %} +{% endfor %} +{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %} + +{% endif %} \ No newline at end of file diff --git a/requirements-dev.txt b/requirements-dev.txt index cf1529274908..f8126008d079 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -13,4 
+13,9 @@ types-setuptools pytest pytest-forked pytest-asyncio - +httpx +einops # required for MPT +flash_attn # required for HuggingFace's llama implementation +openai +requests +ray \ No newline at end of file diff --git a/setup.py b/setup.py index 811d494e7a01..fe8cd6d75ed7 100644 --- a/setup.py +++ b/setup.py @@ -293,6 +293,11 @@ def get_requirements() -> List[str]: return requirements +package_data = {"vllm": ["py.typed"]} +if os.environ.get("VLLM_USE_PRECOMPILED"): + ext_modules = [] + package_data["vllm"].append("*.so") + setuptools.setup( name="vllm", version=get_vllm_version(), @@ -321,5 +326,5 @@ def get_requirements() -> List[str]: install_requires=get_requirements(), ext_modules=ext_modules, cmdclass={"build_ext": BuildExtension}, - package_data={"vllm": ["py.typed"]}, + package_data=package_data, ) diff --git a/tests/async_engine/test_api_server.py b/tests/async_engine/test_api_server.py index 0b45e10dc555..ed9017c1e3e9 100644 --- a/tests/async_engine/test_api_server.py +++ b/tests/async_engine/test_api_server.py @@ -29,8 +29,13 @@ def api_server(): script_path = Path(__file__).parent.joinpath( "api_server_async_engine.py").absolute() uvicorn_process = subprocess.Popen([ - sys.executable, "-u", - str(script_path), "--model", "facebook/opt-125m" + sys.executable, + "-u", + str(script_path), + "--model", + "facebook/opt-125m", + "--host", + "127.0.0.1", ]) yield uvicorn_process.terminate() @@ -81,6 +86,9 @@ def test_api_server(api_server): pool.join() # check cancellation stats + # give it some time to update the stats + time.sleep(1) + num_aborted_requests = requests.get( "http://localhost:8000/stats").json()["num_aborted_requests"] assert num_aborted_requests > 0 diff --git a/tests/async_engine/test_openai_server.py b/tests/async_engine/test_chat_template.py similarity index 71% rename from tests/async_engine/test_openai_server.py rename to tests/async_engine/test_chat_template.py index a61ff7e84ca6..32d110e0f0b4 100644 --- 
a/tests/async_engine/test_openai_server.py +++ b/tests/async_engine/test_chat_template.py @@ -1,10 +1,16 @@ -from argparse import Namespace from dataclasses import dataclass +import os +import pathlib import pytest -from fastapi.testclient import TestClient -from vllm.entrypoints.openai.api_server import * +from vllm.transformers_utils.tokenizer import get_tokenizer +from vllm.entrypoints.openai.serving_chat import OpenAIServingChat +from vllm.entrypoints.openai.protocol import ChatCompletionRequest + +chatml_jinja_path = pathlib.Path(os.path.dirname(os.path.abspath( + __file__))).parent.parent / "examples/template_chatml.jinja" +assert chatml_jinja_path.exists() # Define models, templates, and their corresponding expected outputs MODEL_TEMPLATE_GENERATON_OUTPUT = [ @@ -12,8 +18,7 @@ "HelloHi there!What is the capital of"), ("facebook/opt-125m", None, False, "HelloHi there!What is the capital of"), - ("facebook/opt-125m", "../../examples/template_chatml.jinja", True, - """<|im_start|>user + ("facebook/opt-125m", chatml_jinja_path, True, """<|im_start|>user Hello<|im_end|> <|im_start|>assistant Hi there!<|im_end|> @@ -21,8 +26,7 @@ What is the capital of<|im_end|> <|im_start|>assistant """), - ("facebook/opt-125m", "../../examples/template_chatml.jinja", False, - """<|im_start|>user + ("facebook/opt-125m", chatml_jinja_path, False, """<|im_start|>user Hello<|im_end|> <|im_start|>assistant Hi there!<|im_end|> @@ -44,7 +48,6 @@ 'content': 'What is the capital of' }, ] -client = TestClient(app) @dataclass @@ -52,14 +55,17 @@ class MockTokenizer: chat_template = None +@dataclass +class MockServingChat: + tokenizer: MockTokenizer + + def test_load_chat_template(): # Testing chatml template - template = "../../examples/template_chatml.jinja" - mock_args = Namespace(chat_template=template) tokenizer = MockTokenizer() - - # Call the function with the mocked args - load_chat_template(mock_args, tokenizer) + mock_serving_chat = MockServingChat(tokenizer) + 
OpenAIServingChat._load_chat_template(mock_serving_chat, + chat_template=chatml_jinja_path) template_content = tokenizer.chat_template @@ -73,11 +79,11 @@ def test_load_chat_template(): def test_no_load_chat_template(): # Testing chatml template template = "../../examples/does_not_exist" - mock_args = Namespace(chat_template=template) tokenizer = MockTokenizer() - # Call the function with the mocked args - load_chat_template(mock_args, tokenizer=tokenizer) + mock_serving_chat = MockServingChat(tokenizer) + OpenAIServingChat._load_chat_template(mock_serving_chat, + chat_template=template) template_content = tokenizer.chat_template # Test assertions @@ -94,9 +100,9 @@ async def test_get_gen_prompt(model, template, add_generation_prompt, expected_output): # Initialize the tokenizer tokenizer = get_tokenizer(tokenizer_name=model) - - mock_args = Namespace(chat_template=template) - load_chat_template(mock_args, tokenizer) + mock_serving_chat = MockServingChat(tokenizer) + OpenAIServingChat._load_chat_template(mock_serving_chat, + chat_template=template) # Create a mock request object using keyword arguments mock_request = ChatCompletionRequest( @@ -112,8 +118,3 @@ async def test_get_gen_prompt(model, template, add_generation_prompt, # Test assertion assert result == expected_output, f"The generated prompt does not match the expected output for model {model} and template {template}" - - -def test_health_endpoint(): - response = client.get("/health") - assert response.status_code == 200 diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py index b9895b3e7179..75111feb3950 100644 --- a/tests/distributed/test_comm_ops.py +++ b/tests/distributed/test_comm_ops.py @@ -2,10 +2,9 @@ Run `pytest tests/distributed/test_comm_ops.py --forked`. 
""" -from multiprocessing import Process, set_start_method - import pytest import torch +import ray from vllm.config import ParallelConfig from vllm.utils import get_open_port @@ -23,11 +22,11 @@ def init_test_distributed_environment(pipeline_parallel_size: int, tensor_parallel_size, worker_use_ray=True) distributed_init_method = f"tcp://localhost:{distributed_init_port}" - torch.cuda.set_device(rank) _init_distributed_environment(parallel_config, rank, distributed_init_method) +@ray.remote(num_gpus=1, max_calls=1) def all_reduce_test_worker(tensor_parallel_size: int, rank: int, distributed_init_port: str): init_test_distributed_environment(1, tensor_parallel_size, rank, @@ -43,6 +42,7 @@ def all_reduce_test_worker(tensor_parallel_size: int, rank: int, assert torch.allclose(t, expected) +@ray.remote(num_gpus=1, max_calls=1) def all_gather_test_worker(tensor_parallel_size: int, rank: int, distributed_init_port: str): init_test_distributed_environment(1, tensor_parallel_size, rank, @@ -70,14 +70,16 @@ def all_gather_test_worker(tensor_parallel_size: int, rank: int, @pytest.mark.parametrize("test_target", [all_reduce_test_worker, all_gather_test_worker]) def test_multi_process_tensor_parallel(tensor_parallel_size, test_target): - set_start_method("spawn", force=True) + # Using ray helps debugging the error when it failed + # as compared to multiprocessing. 
+ ray.init() + + distributed_init_port = get_open_port() - processes = [] + refs = [] for rank in range(tensor_parallel_size): - p = Process(target=test_target, - args=(tensor_parallel_size, rank, distributed_init_port)) - p.start() - processes.append(p) - for p in processes: - p.join() - assert all(p.exitcode == 0 for p in processes) + refs.append( + test_target.remote(tensor_parallel_size, rank, + distributed_init_port)) + ray.get(refs) + + ray.shutdown() diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py new file mode 100644 index 000000000000..707ab6d28d92 --- /dev/null +++ b/tests/entrypoints/test_openai_server.py @@ -0,0 +1,193 @@ +import time +import subprocess + +import sys +import pytest +import requests +import ray # using Ray for overall ease of process management, parallel requests, and debugging. +import openai # use the official client for correctness check + +MAX_SERVER_START_WAIT_S = 600 # wait for server to start for 600 seconds +MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" # any model with a chat template should work here + +pytestmark = pytest.mark.asyncio + + +@ray.remote(num_gpus=1) +class ServerRunner: + + def __init__(self, args): + self.proc = subprocess.Popen( + ["python3", "-m", "vllm.entrypoints.openai.api_server"] + args, + stdout=sys.stdout, + stderr=sys.stderr, + ) + self._wait_for_server() + + def ready(self): + return True + + def _wait_for_server(self): + # run health check + start = time.time() + while True: + try: + if requests.get( + "http://localhost:8000/health").status_code == 200: + break + except Exception as err: + if self.proc.poll() is not None: + raise RuntimeError("Server exited unexpectedly.") from err + + time.sleep(0.5) + if time.time() - start > MAX_SERVER_START_WAIT_S: + raise RuntimeError( + "Server failed to start in time.") from err + + def __del__(self): + if hasattr(self, "proc"): + self.proc.terminate() + + +@pytest.fixture(scope="session") +def server(): + ray.init() + 
server_runner = ServerRunner.remote([ + "--model", + MODEL_NAME, + "--dtype", + "bfloat16", # use half precision for speed and memory savings in CI environment + "--max-model-len", + "8192" + ]) + ray.get(server_runner.ready.remote()) + yield server_runner + ray.shutdown() + + +@pytest.fixture(scope="session") +def client(): + client = openai.AsyncOpenAI( + base_url="http://localhost:8000/v1", + api_key="token-abc123", + ) + yield client + + +async def test_single_completion(server, client: openai.AsyncOpenAI): + completion = await client.completions.create(model=MODEL_NAME, + prompt="Hello, my name is", + max_tokens=5, + temperature=0.0) + + assert completion.id is not None + assert completion.choices is not None and len(completion.choices) == 1 + assert completion.choices[0].text is not None and len( + completion.choices[0].text) >= 5 + assert completion.choices[0].finish_reason == "length" + assert completion.usage == openai.types.CompletionUsage( + completion_tokens=5, prompt_tokens=6, total_tokens=11) + + +async def test_single_chat_session(server, client: openai.AsyncOpenAI): + messages = [{ + "role": "system", + "content": "you are a helpful assistant" + }, { + "role": "user", + "content": "what is 1+1?" 
+ }] + + # test single completion + chat_completion = await client.chat.completions.create( + model=MODEL_NAME, + messages=messages, + max_tokens=10, + ) + assert chat_completion.id is not None + assert chat_completion.choices is not None and len( + chat_completion.choices) == 1 + assert chat_completion.choices[0].message is not None + message = chat_completion.choices[0].message + assert message.content is not None and len(message.content) >= 10 + assert message.role == "assistant" + messages.append({"role": "assistant", "content": message.content}) + + # test multi-turn dialogue + messages.append({"role": "user", "content": "express your result in json"}) + chat_completion = await client.chat.completions.create( + model=MODEL_NAME, + messages=messages, + max_tokens=10, + ) + message = chat_completion.choices[0].message + assert message.content is not None and len(message.content) >= 0 + + +async def test_completion_streaming(server, client: openai.AsyncOpenAI): + prompt = "What is an LLM?" + + single_completion = await client.completions.create( + model=MODEL_NAME, + prompt=prompt, + max_tokens=5, + temperature=0.0, + ) + single_output = single_completion.choices[0].text + single_usage = single_completion.usage + + stream = await client.completions.create( + model=MODEL_NAME, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=True, + ) + chunks = [] + async for chunk in stream: + chunks.append(chunk.choices[0].text) + assert chunk.choices[0].finish_reason == "length" + assert chunk.usage == single_usage + assert "".join(chunks) == single_output + + +async def test_chat_streaming(server, client: openai.AsyncOpenAI): + messages = [{ + "role": "system", + "content": "you are a helpful assistant" + }, { + "role": "user", + "content": "what is 1+1?" 
+ }] + + # test single completion + chat_completion = await client.chat.completions.create( + model=MODEL_NAME, + messages=messages, + max_tokens=10, + temperature=0.0, + ) + output = chat_completion.choices[0].message.content + stop_reason = chat_completion.choices[0].finish_reason + + # test streaming + stream = await client.chat.completions.create( + model=MODEL_NAME, + messages=messages, + max_tokens=10, + temperature=0.0, + stream=True, + ) + chunks = [] + async for chunk in stream: + delta = chunk.choices[0].delta + if delta.role: + assert delta.role == "assistant" + if delta.content: + chunks.append(delta.content) + assert chunk.choices[0].finish_reason == stop_reason + assert "".join(chunks) == output + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 814d40f56def..3949948e860f 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -13,7 +13,7 @@ # This will change depending on the compute capability. 
# - 512 as a buffer MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512 -NUM_BLOCKS = 40000 # Arbitrary values for testing +NUM_BLOCKS = 12000 # Arbitrary values for testing PARTITION_SIZE = 512 DTYPES = [torch.half, torch.bfloat16, torch.float] diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index 1d8d41e013b0..7b1cc058f2cb 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -6,12 +6,12 @@ from vllm._C import cache_ops DTYPES = [torch.half, torch.bfloat16, torch.float] -NUM_TOKENS = [83] # Arbitrary values for testing +NUM_TOKENS = [42] # Arbitrary values for testing NUM_LAYERS = [1] # Arbitrary values for testing NUM_HEADS = [8] # Arbitrary values for testing HEAD_SIZES = [64, 80, 96, 112, 128, 256] BLOCK_SIZES = [8, 16, 32] -NUM_BLOCKS = [1024, 36000] # Arbitrary values for testing +NUM_BLOCKS = [1024, 3600] # Arbitrary values for testing NUM_MAPPINGS = [256] # Arbitrary values for testing SEEDS = [0] DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)] diff --git a/tests/models/test_models.py b/tests/models/test_models.py index 518eae201ed3..40858a517b31 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -5,18 +5,11 @@ import pytest MODELS = [ - "facebook/opt-125m", - "meta-llama/Llama-2-7b-hf", - "mistralai/Mistral-7B-v0.1", - "Deci/DeciLM-7b", - "tiiuae/falcon-7b", - "gpt2", - "bigcode/tiny_starcoder_py", - "EleutherAI/gpt-j-6b", - "EleutherAI/pythia-70m", - "bigscience/bloom-560m", - "mosaicml/mpt-7b", - "microsoft/phi-2", + "facebook/opt-125m", "meta-llama/Llama-2-7b-hf", + "mistralai/Mistral-7B-v0.1", "Deci/DeciLM-7b", "tiiuae/falcon-7b", "gpt2", + "bigcode/tiny_starcoder_py", "EleutherAI/gpt-j-6b", + "EleutherAI/pythia-70m", "bigscience/bloom-560m", "mosaicml/mpt-7b", + "microsoft/phi-2", "stabilityai/stablelm-3b-4e1t" ] diff --git a/tests/samplers/test_logprobs.py b/tests/samplers/test_logprobs.py index 1c67cc5bd739..0ea3704462fc 100644 --- 
a/tests/samplers/test_logprobs.py +++ b/tests/samplers/test_logprobs.py @@ -30,6 +30,7 @@ def test_get_prompt_logprobs( temperature=0.0) vllm_results = vllm_model.model.generate( example_prompts, sampling_params=vllm_sampling_params) + del vllm_model # Test whether logprobs are included in the results. for result in vllm_results: diff --git a/tests/samplers/test_rejection_sampler.py b/tests/samplers/test_rejection_sampler.py new file mode 100644 index 000000000000..9d3ef3c67d3d --- /dev/null +++ b/tests/samplers/test_rejection_sampler.py @@ -0,0 +1,392 @@ +"""Tests for rejection sampling.""" +import pytest +from typing import List, Tuple + +import torch +import torch.nn.functional as F + +from vllm.model_executor.utils import set_random_seed + +from vllm.model_executor.layers.rejection_sampler import RejectionSampler + + +def mock_causal_accepted_tensor( + k: int, last_accepted_indices: torch.Tensor) -> torch.Tensor: + """Generate an "accepted" tensor which should yield causally-accepted tokens + up to last accepted indices. + + Tokens after last_accepted_indices+1 may also be accepted, although they + will not be causally accepted. + """ + batch_size = last_accepted_indices.shape[0] + + accepted = (torch.arange(k).expand(batch_size, k) <= + last_accepted_indices.unsqueeze(-1).broadcast_to( + batch_size, k)).to(device="cuda") + + # Sprinkle accepted values after the contiguous initial accepted values. + # This replicates the behavior of rejection sampling, which may "accept" + # a token that cannot be accepted because of causality. 
+ sprinkle_candidates = ( + torch.arange(k).expand(batch_size, k) > + last_accepted_indices.unsqueeze(-1).broadcast_to(batch_size, k) + 1) + sprinkle = torch.rand(batch_size, k, device="cuda") > 0.5 + accepted[sprinkle_candidates] = sprinkle[sprinkle_candidates] + return accepted + + +@pytest.mark.parametrize("seed", list(range(10))) +@pytest.mark.parametrize( + "which_tokens_accepted", + ["all_tokens_accepted", "no_tokens_accepted", "some_tokens_accepted"]) +@torch.inference_mode() +def test_correct_output_format(which_tokens_accepted: str, seed: int): + """Verify the output has correct format given predetermined accepted matrix. + """ + set_random_seed(seed) + + batch_size = 10 + k = 5 + vocab_size = 3000 + + if which_tokens_accepted == "all_tokens_accepted": + accepted = mock_causal_accepted_tensor( + k, -1 + k * torch.ones((batch_size, ), dtype=torch.long)) + elif which_tokens_accepted == "no_tokens_accepted": + accepted = mock_causal_accepted_tensor( + k, -torch.ones((batch_size, ), dtype=torch.long)) + elif which_tokens_accepted == "some_tokens_accepted": + last_accepted_indices = torch.randint(low=-1, + high=k, + size=(batch_size, )) + accepted = mock_causal_accepted_tensor(k, last_accepted_indices) + else: + raise AssertionError() + + recovered_token_ids = torch.randint(low=0, + high=vocab_size, + size=(batch_size, k), + dtype=torch.int64, + device="cuda") + draft_token_ids = torch.randint(low=0, + high=vocab_size, + size=(batch_size, k), + dtype=torch.int64, + device="cuda") + bonus_token_ids = torch.randint(low=0, + high=vocab_size, + size=(batch_size, 1), + dtype=torch.int64, + device="cuda") + + rejection_sampler = RejectionSampler() + rejection_sampler.init_gpu_tensors(rank=0) + output_token_ids = rejection_sampler._create_output( # pylint: disable=protected-access + accepted, + recovered_token_ids, + draft_token_ids, + bonus_token_ids, + ) + + if which_tokens_accepted == "all_tokens_accepted": + # Expect all tokens to be equal to draft tokens. 
+ assert torch.equal(output_token_ids[:, :-1], draft_token_ids) + + # Expect all bonus tokens to be included. + assert torch.equal(output_token_ids[:, -1:], bonus_token_ids) + elif which_tokens_accepted == "no_tokens_accepted": + # Expect first token to be equal to recovered tokens. + assert torch.equal(output_token_ids[:, 0], recovered_token_ids[:, 0]) + + # Expect everything else to be -1. + assert torch.equal(output_token_ids[:, 1:], + torch.ones_like(output_token_ids[:, 1:]) * -1) + elif which_tokens_accepted == "some_tokens_accepted": + recovered_plus_bonus = torch.cat( + (recovered_token_ids, bonus_token_ids), dim=-1) + # Assert first rejected token is a recovered token or bonus token. + assert torch.equal( + recovered_plus_bonus[torch.arange(0, batch_size), + last_accepted_indices + 1], + output_token_ids[torch.arange(0, batch_size), + last_accepted_indices + 1]) + + # Assert every subsequent token is -1. + subsequent_mask = torch.arange(0, k + 1).expand( + batch_size, k + 1) >= (last_accepted_indices + 2).unsqueeze(-1) + assert torch.all(output_token_ids[subsequent_mask] == -1) + + +@pytest.mark.parametrize("k", list(range(1, 6))) +@pytest.mark.parametrize("vocab_size", [30_000, 50_000]) +@pytest.mark.parametrize("batch_size", list(range(1, 32))) +@torch.inference_mode() +def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int): + rejection_sampler = RejectionSampler() + rejection_sampler.init_gpu_tensors(rank=0) + + draft_probs = torch.rand(batch_size, + k, + vocab_size, + dtype=torch.float32, + device="cuda") + target_probs = torch.rand(batch_size, + k, + vocab_size, + dtype=torch.float32, + device="cuda") + bonus_token_ids = torch.randint(low=0, + high=vocab_size, + size=(batch_size, 1), + dtype=torch.int64, + device="cuda") + draft_token_ids = torch.randint(low=0, + high=vocab_size, + size=(batch_size, k), + dtype=torch.int64, + device="cuda") + + rejection_sampler(target_probs, bonus_token_ids, draft_probs, + draft_token_ids) + + 
+@pytest.mark.parametrize("above_or_below_vocab_range", ["above", "below"]) +@pytest.mark.parametrize("which_token_ids", + ["bonus_token_ids", "draft_token_ids"]) +@torch.inference_mode() +def test_raises_when_vocab_oob(above_or_below_vocab_range: str, + which_token_ids: str): + k = 3 + batch_size = 5 + vocab_size = 30_000 + + rejection_sampler = RejectionSampler(strict_mode=True) + rejection_sampler.init_gpu_tensors(rank=0) + + draft_probs = torch.rand(batch_size, + k, + vocab_size, + dtype=torch.float32, + device="cuda") + target_probs = torch.rand(batch_size, + k, + vocab_size, + dtype=torch.float32, + device="cuda") + bonus_token_ids = torch.randint(low=0, + high=vocab_size, + size=(batch_size, 1), + dtype=torch.int64, + device="cuda") + draft_token_ids = torch.randint(low=0, + high=vocab_size, + size=(batch_size, k), + dtype=torch.int64, + device="cuda") + + oob_token_ids = None + if which_token_ids == "bonus_token_ids": + oob_token_ids = bonus_token_ids + elif which_token_ids == "draft_token_ids": + oob_token_ids = draft_token_ids + else: + raise AssertionError() + + if above_or_below_vocab_range == "above": + rogue_token_id = vocab_size + 1 + elif above_or_below_vocab_range == "below": + rogue_token_id = -1 + else: + raise AssertionError() + + oob_token_ids[0][0] = rogue_token_id + + with pytest.raises(AssertionError): + rejection_sampler(target_probs, bonus_token_ids, draft_probs, + draft_token_ids) + + +@pytest.mark.parametrize("draft_and_target_probs_equal", [True, False]) +@pytest.mark.parametrize("seed", list(range(5))) +@torch.inference_mode() +def test_rejection_sampling_approximates_target_distribution( + seed: int, draft_and_target_probs_equal: bool): + """Verify rejection sampling approximates target distribution, + despite sampling from a potentially distinct draft distribution. + + This is done by first creating a random target probability + distribution and a random draft probability distribution. 
We then + sample token ids from the rejection sampler using these draft + and target distributions. The samples are used to estimate + the output probability distribution, which we expect to approximate + the target distribution. + + A basic distance metric is used to determine similarity between + distributions. + + We expect that as we increase the number of samples, + the distance between the observed distribution and the target + distribution decreases. To measure this, we compare the distance + of the observed distribution against both the target distribution + and a uniform random distribution. We expect the distance between + the observed distribution and the target distribution to improve + much more than the distance improvement between the observed + distribution and the random distribution. + + When draft_and_target_probs_equal=True, the draft and target + probabilities are exactly equal. Rejection sampling should + still work without any NaNs or exceptions. + """ + set_random_seed(seed) + + helper = _CorrectnessTestHelper( + vocab_size=10, + rejection_sampler=RejectionSampler(), + ) + + draft_probs, target_probs, reference_probs = helper.generate_probs_for_test( + draft_and_target_probs_equal) + + sample_sizes = [10, 100, 1_000, 10_000, 100_000] + distance_wrt_reference = [] + distance_wrt_target = [] + + for num_samples in sample_sizes: + (reference_vs_rejsample_dist, + target_vs_rejsample_dist) = helper.run_and_compare_distributions( + draft_probs, + target_probs, + reference_probs, + num_samples, + ) + + distance_wrt_reference.append(reference_vs_rejsample_dist) + distance_wrt_target.append(target_vs_rejsample_dist) + + relative_change_in_distance_wrt_target = get_ratio_first_to_last( + distance_wrt_target) + relative_change_in_distance_wrt_reference = get_ratio_first_to_last( + distance_wrt_reference) + + print(f"{num_samples=} {target_vs_rejsample_dist=:.05f} " + f"{reference_vs_rejsample_dist=:.05f}") + print(f"{num_samples=} 
{relative_change_in_distance_wrt_target=:.02f} " + f"{relative_change_in_distance_wrt_reference=:.02f}") + + relative_change_in_distance_wrt_target = get_ratio_first_to_last( + distance_wrt_target) + relative_change_in_distance_wrt_reference = get_ratio_first_to_last( + distance_wrt_reference) + + expected_improvement_multiplier = 20 + assert (relative_change_in_distance_wrt_target > + relative_change_in_distance_wrt_reference * + expected_improvement_multiplier) + + +def get_ratio_first_to_last(elements: List[float]) -> float: + return elements[0] / elements[-1] + + +class _CorrectnessTestHelper: + """Class that packages together logic required for the unit-level + rejection sampling correctness test. + """ + + def __init__(self, vocab_size: int, rejection_sampler: RejectionSampler): + self.rejection_sampler = rejection_sampler + self.vocab_size = vocab_size + self.vocab_range = (0, vocab_size) + + self.rejection_sampler.init_gpu_tensors(rank=0) + + # Keep test simple, use k=1 + self.k = 1 + + # Bonus tokens not used, but rejection sampler requires + # correct shape. + self.num_bonus_tokens = 1 + + def generate_probs_for_test( + self, draft_and_target_probs_equal: bool + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + draft_probs, target_probs = [ + F.softmax( + torch.rand(self.vocab_size, dtype=torch.float32), + dim=-1, + ) for _ in range(2) + ] + + num_reference_probs = 100 + reference_probs = F.softmax( + torch.rand(num_reference_probs, + self.vocab_size, + dtype=torch.float32), + dim=-1, + ) + + if draft_and_target_probs_equal: + target_probs = draft_probs.clone() + + return draft_probs, target_probs, reference_probs + + def run_and_compare_distributions(self, draft_probs: torch.Tensor, + target_probs: torch.Tensor, + reference_probs: torch.Tensor, + num_samples: int) -> Tuple[float, float]: + # Sample using rejection sampling. 
+ rej_sample_probs = self._estimate_rejection_sampling_pdf( + draft_probs, target_probs, num_samples) + + # Average distance from reference probs. + reference_vs_rejsample_dist = torch.dist( + reference_probs, + rej_sample_probs).item() / reference_probs.shape[0] + target_vs_rejsample_dist = torch.dist(target_probs, + rej_sample_probs).item() + + return reference_vs_rejsample_dist, target_vs_rejsample_dist + + def _estimate_rejection_sampling_pdf( + self, + draft_probs: torch.Tensor, + target_probs: torch.Tensor, + num_samples: int, + ) -> torch.Tensor: + # Repeat draft probs num_samples times. + draft_probs = draft_probs.reshape(1, self.k, self.vocab_size).repeat( + num_samples, 1, 1) + + # Repeat target probs num_samples * k times. + # Rejection sampler requires bonus token probs, but they aren't used. + target_probs = target_probs.reshape(1, 1, self.vocab_size).repeat( + num_samples, self.k, 1) + + # Randomly sample draft token ids from draft probs. + draft_token_ids = torch.multinomial(draft_probs[:, 0, :], + num_samples=1, + replacement=True).reshape( + num_samples, self.k) + + # Bonus tokens not used but required. + bonus_token_ids = torch.zeros((1, self.num_bonus_tokens), + dtype=torch.int64, + device="cuda").repeat(num_samples, 1) + + # Get output tokens via rejection sampling. 
+ output_token_ids = self.rejection_sampler(target_probs.to("cuda"), + bonus_token_ids.to("cuda"), + draft_probs.to("cuda"), + draft_token_ids.to("cuda")) + + # Remove bonus tokens + output_token_ids = output_token_ids[:, :-1].flatten() + + # Estimate probability density function + hist = torch.histogram(output_token_ids.to(dtype=torch.float, + device="cpu"), + bins=self.vocab_size, + range=self.vocab_range, + density=True) + + return hist.hist diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index 3ad2d4608fbd..996aa8e0a8d9 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -4,6 +4,7 @@ import pytest import torch +from transformers import GenerationConfig, GenerationMixin from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.utils import set_random_seed @@ -74,6 +75,8 @@ def test_sampler_all_greedy(seed: int): for nth_output in sequence_output.samples: assert nth_output.output_token == expected[i].item() + del model_runner + @pytest.mark.parametrize("seed", RANDOM_SEEDS) def test_sampler_all_random(seed: int): @@ -110,6 +113,8 @@ def test_sampler_all_random(seed: int): for nth_output in sequence_output.samples: assert nth_output.output_token == i + del model_runner + @pytest.mark.parametrize("seed", RANDOM_SEEDS) def test_sampler_all_beam(seed: int): @@ -143,6 +148,7 @@ def test_sampler_all_beam(seed: int): # the outputs are expected - in other words, this just tests # whether there are no exceptions in the sampler # when handling an all-beam search case. 
+ del model_runner @pytest.mark.parametrize("seed", RANDOM_SEEDS) @@ -197,6 +203,8 @@ def test_sampler_mixed(seed: int): for nth_output in sequence_output.samples: assert nth_output.output_token in expected_tokens + del model_runner + @pytest.mark.parametrize("seed", RANDOM_SEEDS) def test_sampler_logits_processors(seed: int): @@ -233,3 +241,69 @@ def pick_ith(token_ids, logits): for _, sequence_output in enumerate(sampler_output): for idx, nth_output in enumerate(sequence_output.samples): assert nth_output.output_token == idx + + del model_runner + + +@pytest.mark.parametrize("seed", RANDOM_SEEDS) +def test_sampler_top_k_top_p(seed: int): + set_random_seed(seed) + batch_size = random.randint(1, 256) + top_k = random.randint(100, 500) + top_p = random.random() * 0.1 + vocab_size = 32000 + input_tensor = torch.rand((batch_size, 1024), + device="cuda", + dtype=torch.float16) + fake_logits = torch.normal(0, + 5, + size=(batch_size, vocab_size), + device=input_tensor.device, + dtype=input_tensor.dtype) + sampler = MockLogitsSampler(32000, fake_logits) + model_runner = ModelRunner(None, None, None) + + generation_model = GenerationMixin() + generation_config = GenerationConfig(top_k=top_k, + top_p=top_p, + do_sample=True) + warpers = generation_model._get_logits_warper(generation_config) + assert len(warpers) == 2 # top_p and top_k + + seq_group_metadata_list = [] + prompt_lens = [] + for i in range(batch_size): + seq_group_metadata_list.append( + SequenceGroupMetadata( + request_id=f"test_{i}", + is_prompt=True, + seq_data={0: SequenceData([1, 2, 3])}, + sampling_params=SamplingParams( + temperature=1, + top_k=top_k, + top_p=top_p, + ), + block_tables={0: [1]}, + )) + prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) + + sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, + prompt_lens) + + sample_probs = None + + def mock_sample(probs, logprobs, sampling_metadata): + nonlocal sample_probs + sample_probs = probs + return 
[[prob.topk(1, dim=-1).indices.tolist(), [0]] for prob in probs] + + with patch("vllm.model_executor.layers.sampler._sample", mock_sample): + sampler(embedding=None, + hidden_states=input_tensor, + sampling_metadata=sampling_metadata) + hf_probs = warpers(torch.zeros_like(fake_logits), fake_logits.clone()) + hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float) + assert torch.allclose(hf_probs, sample_probs, atol=1e-5) + assert torch.equal(hf_probs.eq(0), sample_probs.eq(0)) + + del model_runner diff --git a/tests/worker/__init__.py b/tests/worker/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/worker/spec_decode/__init__.py b/tests/worker/spec_decode/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/worker/spec_decode/test_multi_step_worker.py b/tests/worker/spec_decode/test_multi_step_worker.py new file mode 100644 index 000000000000..ea5480290357 --- /dev/null +++ b/tests/worker/spec_decode/test_multi_step_worker.py @@ -0,0 +1,261 @@ +import torch +import random +import pytest +from unittest.mock import MagicMock + +from vllm.worker.spec_decode.multi_step_worker import MultiStepWorker +from vllm.worker.worker import Worker +from vllm.model_executor.utils import set_random_seed + +from .utils import (create_execute_model_data, create_worker, + create_seq_group_metadata_from_prompts, zero_kv_cache, + patch_execute_model_with_seeds, + assert_logprobs_dict_allclose) + + +@pytest.mark.parametrize('num_steps', list(range(1, 17))) +def test_assert_enough_kv_space(num_steps: int): + """Test that the multi step worker checks for sufficient space in the KV + cache. It should throw if it cannot run all the steps. 
+ """ + block_size = 16 + num_gpu_blocks = 2048 // block_size + + prompts = [ + list(range(block_size * 3)), + list(range(block_size * 2)), + ] + + prev_output_tokens = [ + list(range(block_size * 1)), + list(range(block_size * 2)), + ] + + final_seq_lens = [ + len(prompt + output) + num_steps + for prompt, output in zip(prompts, prev_output_tokens) + ] + + inputs = create_seq_group_metadata_from_prompts( + prompts, + num_gpu_blocks, + block_size, + final_seq_lens, + continuations=prev_output_tokens) + + assert_enough_kv_space = MultiStepWorker._assert_enough_kv_space # pylint: disable=protected-access + worker = MagicMock() + worker.model_runner.block_size = block_size + + for seq_group_metadata in inputs: + original_block_tables = seq_group_metadata.block_tables + + # No exception. + assert_enough_kv_space(worker, inputs, num_steps) + + seq_group_metadata.block_tables = { + seq_id: [] + for seq_id, physical_blocks in original_block_tables.items() + } + + # Expect exception. + with pytest.raises(ValueError, + match='times but found insufficient KV space for'): + assert_enough_kv_space(worker, inputs, num_steps) + + seq_group_metadata.block_tables = original_block_tables + + +@torch.inference_mode() +def test_same_output_for_single_step(): + """Verify the multi step worker produces the same output as the normal + worker for num_steps=1. 
+ """ + seed = 100 + model_name = 'JackFram/llama-68m' + + block_size = 32 + num_gpu_blocks = 2048 // block_size + multi_step_worker = create_worker( + MultiStepWorker, + model_name, + block_size, + num_gpu_blocks, + seed, + ) + worker = create_worker( + Worker, + model_name, + block_size, + num_gpu_blocks, + seed, + ) + multi_step_worker.model_runner = worker.model_runner + multi_step_worker.cache_engine = worker.cache_engine + + num_steps = 1 + + prompts = [ + [1, 2, 3, 4, 5], + [6, 7, 8, 9, 10], + ] + + final_seq_lens = [len(prompt) + num_steps for prompt in prompts] + + multi_step_execute_model_data = create_execute_model_data( + seq_group_metadata_list=create_seq_group_metadata_from_prompts( + prompts, num_gpu_blocks, block_size, + final_seq_lens=final_seq_lens)) + + single_step_execute_model_data = create_execute_model_data( + seq_group_metadata_list=create_seq_group_metadata_from_prompts( + prompts, num_gpu_blocks, block_size, + final_seq_lens=final_seq_lens)) + + zero_kv_cache(multi_step_worker.cache_engine) + set_random_seed(seed) + actual_output = multi_step_worker.execute_model_multi_step( + **multi_step_execute_model_data.to_dict(), num_steps=num_steps) + assert len(actual_output) == num_steps + actual_output = actual_output[0] + + zero_kv_cache(worker.cache_engine) + set_random_seed(seed) + expected_output = worker.execute_model( + **single_step_execute_model_data.to_dict(), ) + + actual_token_ids = [ + output.samples[0].output_token for output in actual_output + ] + actual_logprobs = [output.samples[0].logprobs for output in actual_output] + + expected_token_ids = [ + output.samples[0].output_token for output in expected_output + ] + expected_logprobs = [ + output.samples[0].logprobs for output in expected_output + ] + + assert actual_token_ids == expected_token_ids + + print(f'{actual_logprobs=}') + print(f'{expected_logprobs=}') + assert_logprobs_dict_allclose(actual_logprobs, expected_logprobs) + + +@torch.inference_mode() +def 
test_same_output_for_multi_step(): + """Verify the multi-step worker produces the same output as the normal + worker when num_steps > 1. This test runs the multi-step worker once, and + then runs the worker num_steps times, and compares the output. + """ + seed = 100 + model_name = 'JackFram/llama-68m' + + block_size = 16 + num_gpu_blocks = 2048 // block_size + multi_step_worker = create_worker( + MultiStepWorker, + model_name, + block_size, + num_gpu_blocks, + seed, + ) + + worker = create_worker( + Worker, + model_name, + block_size, + num_gpu_blocks, + seed, + ) + + # Make sure we go over the block boundary. + num_steps = block_size + 1 + + random.seed(seed) + prompts = [[ + random.randint(0, 1000) for _ in range(random.randint(10, 20)) + ] for _ in range(10)] + + final_seq_lens = [len(prompt) + num_steps for prompt in prompts] + + rand_seeds = list(random.randint(0, 100) for _ in range(num_steps)) + multi_step_worker.execute_model = patch_execute_model_with_seeds( + multi_step_worker, rand_seeds) + worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds) + + continuations = [[1] for _ in prompts] + execute_model_data = create_execute_model_data( + create_seq_group_metadata_from_prompts( + prompts, + num_gpu_blocks, + block_size, + continuations=continuations, + final_seq_lens=final_seq_lens), ) + + # Run multi-step. + zero_kv_cache(multi_step_worker.cache_engine) + set_random_seed(seed) + multi_step_output = multi_step_worker.execute_model_multi_step( + **execute_model_data.to_dict(), num_steps=num_steps) + + # Run single-step repeatedly. 
+ zero_kv_cache(worker.cache_engine) + single_step_output = [] + continuations = [[1] for _ in prompts] + set_random_seed(seed) + + for _ in multi_step_output: + + execute_model_data = create_execute_model_data( + create_seq_group_metadata_from_prompts( + prompts, + num_gpu_blocks, + block_size, + continuations=continuations, + final_seq_lens=final_seq_lens)) + + single_step_output.append( + worker.execute_model(**execute_model_data.to_dict(), )) + + # Append output tokens to new sequence data. + for i, seq_group_output in enumerate(single_step_output[-1]): + continuations[i].append(seq_group_output.samples[0].output_token) + + # Get token ids and logprobs for comparison. + multi_step_output_logprobs = [[] for _ in prompts] + single_step_output_logprobs = [[] for _ in prompts] + + multi_step_output_token_ids = [[] for _ in prompts] + single_step_output_token_ids = [[] for _ in prompts] + for i, _ in enumerate(prompts): + for multi_step, single_step in zip(multi_step_output, + single_step_output): + multi_step_output_token_ids[i].append( + multi_step[i].samples[0].output_token) + single_step_output_token_ids[i].append( + single_step[i].samples[0].output_token) + + multi_step_output_logprobs[i].append( + multi_step[i].samples[0].logprobs) + single_step_output_logprobs[i].append( + single_step[i].samples[0].logprobs) + + # Print per-sequence token ids + for i, (multi_step_tokens, single_step_tokens) in enumerate( + zip(multi_step_output_token_ids, single_step_output_token_ids)): + print(f'{i=} {multi_step_tokens=}') + print(f'{i=} {single_step_tokens=}') + print(f'{i=} equal {multi_step_tokens == single_step_tokens}') + + # Assert token ids are equal. + for multi_step_tokens, single_step_tokens in zip( + multi_step_output_token_ids, single_step_output_token_ids): + assert multi_step_tokens == single_step_tokens + + # Assert logprobs are equal. 
+ for multi_step_logprobs, single_step_logprobs in zip( + multi_step_output_logprobs, single_step_output_logprobs): + assert_logprobs_dict_allclose(multi_step_logprobs, + single_step_logprobs) diff --git a/tests/worker/spec_decode/utils.py b/tests/worker/spec_decode/utils.py new file mode 100644 index 000000000000..812033829394 --- /dev/null +++ b/tests/worker/spec_decode/utils.py @@ -0,0 +1,177 @@ +import torch +from typing import List, Optional, Dict + +from vllm.worker.worker import Worker +from vllm.utils import get_distributed_init_method, get_ip, get_open_port +from vllm.engine.arg_utils import EngineArgs +from vllm.sequence import SequenceGroupMetadata, SequenceData +from vllm.sampling_params import SamplingParams +from vllm.worker.cache_engine import CacheEngine +from vllm.model_executor.utils import set_random_seed +from dataclasses import dataclass, fields + + +@dataclass +class ExecuteModelData: + """Helper data structure which facilitates cleaner tests. + """ + seq_group_metadata_list: List[SequenceGroupMetadata] + blocks_to_swap_in: Dict[int, int] + blocks_to_swap_out: Dict[int, int] + blocks_to_copy: Dict[int, List[int]] + + def to_dict(self): + return dict( + (field.name, getattr(self, field.name)) for field in fields(self)) + + +def round_up_to_next_block(seq_len: int, block_size: int) -> int: + return (seq_len + block_size - 1) // block_size + + +def create_execute_model_data( + seq_group_metadata_list: List[SequenceGroupMetadata], + blocks_to_swap_in: Optional[Dict[int, int]] = None, + blocks_to_swap_out: Optional[Dict[int, int]] = None, + blocks_to_copy: Optional[Dict[int, int]] = None, +) -> ExecuteModelData: + if blocks_to_swap_in is None: + blocks_to_swap_in = {} + if blocks_to_swap_out is None: + blocks_to_swap_out = {} + if blocks_to_copy is None: + blocks_to_copy = {} + + return ExecuteModelData( + seq_group_metadata_list=seq_group_metadata_list, + blocks_to_swap_in=blocks_to_swap_in, + blocks_to_swap_out=blocks_to_swap_out, + 
blocks_to_copy=blocks_to_copy, + ) + + +def patch_execute_model_with_seeds(worker: Worker, rand_seeds: List[int]): + seed_iter = iter(rand_seeds) + original_execute_model = worker.execute_model + + def new_execute_model(*args, **kwargs): + result = original_execute_model(*args, **kwargs) + set_random_seed(next(seed_iter)) + return result + + return new_execute_model + + +def zero_kv_cache(cache_engine: CacheEngine): + assert cache_engine.gpu_cache + for key_blocks, value_blocks in cache_engine.gpu_cache: + key_blocks.zero_() + value_blocks.zero_() + + +def create_worker(cls: type, + model_name: str, + block_size: int, + num_gpu_blocks: int, + seed: int, + is_driver_worker: bool = True, + enforce_eager: bool = True): + engine_args = EngineArgs( + model=model_name, + seed=seed, + block_size=block_size, + enforce_eager=enforce_eager, + ) + + (model_config, cache_config, parallel_config, + scheduler_config) = engine_args.create_engine_configs() + + distributed_init_method = get_distributed_init_method( + get_ip(), get_open_port()) + + worker = cls( + model_config=model_config, + parallel_config=parallel_config, + scheduler_config=scheduler_config, + local_rank=0, + rank=0, + distributed_init_method=distributed_init_method, + is_driver_worker=is_driver_worker, + ) + + worker.init_model() + worker.load_model() + + cache_config.num_gpu_blocks = num_gpu_blocks + cache_config.num_cpu_blocks = 0 + worker.init_cache_engine(cache_config) + worker.warm_up_model() + + return worker + + +def create_seq_group_metadata_from_prompts( + prompts: List[List[int]], + num_gpu_blocks: int, + block_size: int, + final_seq_lens: List[int], + continuations: Optional[List[List[int]]] = None, + num_tokens_processed: Optional[List[int]] = None, + seq_ids: Optional[List[int]] = None, +) -> List[SequenceGroupMetadata]: + + if continuations is None: + continuations = [[] for _ in prompts] + + if num_tokens_processed is None: + # Default to 1 token missing from kv cache for generation sequences. 
+ num_tokens_processed = [] + for continuation, prompt in zip(continuations, prompts): + # If prefill, then default to zero tokens processed. + if not continuation: + num_tokens_processed.append(0) + else: + # If generation, then default to all but one tokens processed. + num_tokens_processed.append( + len(continuation) + len(prompt) - 1) + + if seq_ids is None: + seq_ids = list(i for i, _ in enumerate(prompts)) + + free_gpu_blocks = list(range(num_gpu_blocks)) + + block_allocations = { + i: [ + free_gpu_blocks.pop() + for _ in range(round_up_to_next_block(final_len, block_size)) + ] + for i, final_len in enumerate(final_seq_lens) + } + + return [ + SequenceGroupMetadata( + request_id=str(i), + is_prompt=len(cont_token_ids) == 0, + seq_data={ + i: + SequenceData(prompt_token_ids=prompt_token_ids[:] + + cont_token_ids[:]) + }, + sampling_params=SamplingParams(temperature=0.0, ), + block_tables={i: block_allocations[i][:]}, + ) for i, (prompt_token_ids, cont_token_ids, num_tokens_saved) in + enumerate(zip(prompts, continuations, num_tokens_processed)) + ] + + +def assert_logprobs_dict_allclose( + actual_logprobs: List[Dict[int, float]], + expected_logprobs: List[Dict[int, float]]) -> None: + for single_step_actual_logprobs, single_step_expected_logprobs in zip( + actual_logprobs, expected_logprobs): + assert set(single_step_actual_logprobs.keys()) == set( + single_step_expected_logprobs.keys()) + for token_id in single_step_actual_logprobs: + actual = torch.tensor(single_step_actual_logprobs[token_id]) + expected = torch.tensor(single_step_expected_logprobs[token_id]) + assert torch.allclose(actual, expected) diff --git a/vllm/core/policy.py b/vllm/core/policy.py index 3beabb1006a6..99f183b42c8b 100644 --- a/vllm/core/policy.py +++ b/vllm/core/policy.py @@ -1,4 +1,5 @@ -from typing import List +from collections import deque +from typing import Deque from vllm.sequence import SequenceGroup @@ -15,13 +16,14 @@ def get_priority( def sort_by_priority( self, now: float, - 
def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
    """Aborts a sequence group with the given ID.

    Check if the sequence group with the given ID
    is present in any of the state queues (waiting, running, swapped).
    If present, remove the sequence group from the state queue.
    Also, if any of the sequences in the sequence group is not finished,
    free the sequence with status `FINISHED_ABORTED`.
    Otherwise, do nothing.

    Args:
        request_id: The ID(s) of the sequence group(s) to abort.
    """
    if isinstance(request_id, str):
        request_id = (request_id, )
    request_ids = set(request_id)
    for state_queue in [self.waiting, self.running, self.swapped]:
        aborted_groups = []
        for seq_group in state_queue:
            if not request_ids:
                # Using 'break' here may add two extra iterations,
                # but is acceptable to reduce complexity.
                break
            if seq_group.request_id in request_ids:
                # Defer removal: mutating a deque while iterating it
                # is not allowed, so collect matches first.
                aborted_groups.append(seq_group)
                request_ids.remove(seq_group.request_id)
        for aborted_group in aborted_groups:
            # Remove the sequence group from the state queue.
            state_queue.remove(aborted_group)
            # BUG FIX: iterate the sequences of the group being removed
            # (`aborted_group`), not the stale `seq_group` loop variable
            # left over from the scan above — otherwise the wrong group's
            # sequences are freed when several groups are aborted.
            for seq in aborted_group.get_seqs():
                if seq.is_finished():
                    continue
                seq.status = SequenceStatus.FINISHED_ABORTED
                self.free_seq(seq)
@@ -166,7 +181,7 @@ def _schedule(self) -> SchedulerOutputs: for seq in waiting_seqs: seq.status = SequenceStatus.FINISHED_IGNORED ignored_seq_groups.append(seq_group) - self.waiting.pop(0) + self.waiting.popleft() continue # If the number of batched tokens exceeds the limit, stop. @@ -188,7 +203,7 @@ def _schedule(self) -> SchedulerOutputs: break seq_lens = new_seq_lens - seq_group = self.waiting.pop(0) + seq_group = self.waiting.popleft() self._allocate(seq_group) self.running.append(seq_group) num_curr_seqs += num_new_seqs @@ -214,14 +229,14 @@ def _schedule(self) -> SchedulerOutputs: self.running = self.policy.sort_by_priority(now, self.running) # Reserve new token slots for the running sequence groups. - running: List[SequenceGroup] = [] + running: Deque[SequenceGroup] = deque() preempted: List[SequenceGroup] = [] while self.running: - seq_group = self.running.pop(0) + seq_group = self.running.popleft() while not self.block_manager.can_append_slot(seq_group): if self.running: # Preempt the lowest-priority sequence groups. - victim_seq_group = self.running.pop(-1) + victim_seq_group = self.running.pop() self._preempt(victim_seq_group, blocks_to_swap_out) preempted.append(victim_seq_group) else: @@ -255,7 +270,7 @@ def _schedule(self) -> SchedulerOutputs: self.scheduler_config.max_num_seqs): break - seq_group = self.swapped.pop(0) + seq_group = self.swapped.popleft() self._swap_in(seq_group, blocks_to_swap_in) self._append_slot(seq_group, blocks_to_copy) num_curr_seqs += num_new_seqs @@ -376,7 +391,7 @@ def _preempt_by_recompute( self.block_manager.free(seq) # NOTE: For FCFS, we insert the preempted sequence group to the front # of the waiting queue. 
- self.waiting.insert(0, seq_group) + self.waiting.appendleft(seq_group) def _preempt_by_swap( self, diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 92f23ec29bfd..8a5b00ca7f7c 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -253,7 +253,8 @@ class AsyncLLMEngine: log_requests: Whether to log the requests. start_engine_loop: If True, the background task to run the engine will be automatically started in the generate call. - *args, *kwargs: Arguments for LLMEngine. + *args: Arguments for LLMEngine. + *kwargs: Arguments for LLMEngine. """ _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine @@ -428,6 +429,49 @@ async def generate( Yields: The output `RequestOutput` objects from the LLMEngine for the request. + + Details: + - If the engine is not running, start the background loop, + which iteratively invokes + :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step` + to process the waiting requests. + - Add the request to the engine's `RequestTracker`. + On the next background loop, this request will be sent to + the underlying engine. + Also, a corresponding `AsyncStream` will be created. + - Wait for the request outputs from `AsyncStream` and yield them. + + Example: + >>> # Please refer to entrypoints/api_server.py for + >>> # the complete example. 
+ >>> + >>> # initialize the engine and the example input + >>> engine = AsyncLLMEngine.from_engine_args(engine_args) + >>> example_input = { + >>> "prompt": "What is LLM?", + >>> "stream": False, # assume the non-streaming case + >>> "temperature": 0.0, + >>> "request_id": 0, + >>> } + >>> + >>> # start the generation + >>> results_generator = engine.generate( + >>> example_input["prompt"], + >>> SamplingParams(temperature=example_input["temperature"]), + >>> example_input["request_id"]) + >>> + >>> # get the results + >>> final_output = None + >>> async for request_output in results_generator: + >>> if await request.is_disconnected(): + >>> # Abort the request if the client disconnects. + >>> await engine.abort(request_id) + >>> # Return or raise an error + >>> ... + >>> final_output = request_output + >>> + >>> # Process and return the final output + >>> ... """ # Preprocess the request. # This should not be used for logging, as it is monotonic time. diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index bf5a11f39e82..4bfcb2577024 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -18,7 +18,7 @@ SequenceGroupOutput, SequenceOutput, SequenceStatus) from vllm.transformers_utils.tokenizer import (detokenize_incrementally, get_tokenizer) -from vllm.utils import Counter, set_cuda_visible_devices, get_ip, get_open_port +from vllm.utils import Counter, set_cuda_visible_devices, get_ip, get_open_port, get_distributed_init_method if ray: from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy @@ -132,7 +132,8 @@ def _init_workers(self): "Ray is required if parallel_config.world_size > 1.") self.workers: List[Worker] = [] - distributed_init_method = f"tcp://{get_ip()}:{get_open_port()}" + distributed_init_method = get_distributed_init_method( + get_ip(), get_open_port()) self.driver_worker = Worker( self.model_config, self.parallel_config, @@ -207,7 +208,8 @@ def _init_workers_ray(self, placement_group: 
"PlacementGroup", for worker, (node_id, _) in zip(self.workers, worker_node_and_gpu_ids): worker.set_cuda_visible_devices.remote(node_gpus[node_id]) - distributed_init_method = f"tcp://{driver_ip}:{get_open_port()}" + distributed_init_method = get_distributed_init_method( + driver_ip, get_open_port) # Lazy import the Worker to avoid importing torch.cuda/xformers # before CUDA_VISIBLE_DEVICES is set in the Worker @@ -257,7 +259,26 @@ def _verify_args(self) -> None: self.cache_config.verify_with_parallel_config(self.parallel_config) def _init_cache(self) -> None: - """Profiles the memory usage and initializes the KV cache.""" + """Profiles the memory usage and initializes the KV cache. + + The engine will first conduct a profiling of the existing memory usage. + Then, it calculate the maximum possible number of GPU and CPU blocks + that can be allocated with the remaining free memory. + More details can be found in the + :meth:`~vllm.worker.worker.Worker.profile_num_available_blocks` method + from class :class:`~vllm.worker.Worker`. + + Afterwards, as there may be multiple workers, + we take the minimum number of blocks across all workers + to ensure this can be applied to all of them. + + Finally, the engine will initialize the KV cache + with the calculated number of blocks. + + .. tip:: + You may limit the usage of GPU memory + by adjusting the `gpu_memory_utilization` parameters. + """ # Get the maximum number of blocks that can be allocated on GPU and CPU. num_blocks = self._run_workers( "profile_num_available_blocks", @@ -334,6 +355,30 @@ def add_request( use the tokenizer to convert the prompts to token IDs. arrival_time: The arrival time of the request. If None, we use the current monotonic time. + + Details: + - Set arrival_time to the current time if it is None. + - Set prompt_token_ids to the encoded prompt if it is None. + - Create `best_of` number of :class:`~vllm.Sequence` objects. 
+ - Create a :class:`~vllm.SequenceGroup` object + from the list of :class:`~vllm.Sequence`. + - Add the :class:`~vllm.SequenceGroup` object to the scheduler. + + Example: + >>> # initialize engine + >>> engine = LLMEngine.from_engine_args(engine_args) + >>> # set request arguments + >>> example_prompt = "Who is the president of the United States?" + >>> sampling_params = SamplingParams(temperature=0.0) + >>> request_id = 0 + >>> + >>> # add the request to the engine + >>> engine.add_request( + >>> str(request_id), + >>> example_prompt, + >>> SamplingParams(temperature=0.0)) + >>> # continue the request processing + >>> ... """ if arrival_time is None: arrival_time = time.monotonic() @@ -358,6 +403,17 @@ def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: Args: request_id: The ID(s) of the request to abort. + + Details: + - Refer to the + :meth:`~vllm.core.scheduler.Scheduler.abort_seq_group` + from class :class:`~vllm.core.scheduler.Scheduler`. + + Example: + >>> # initialize engine and add a request with request_id + >>> request_id = str(0) + >>> # abort the request + >>> engine.abort_request(request_id) """ self.scheduler.abort_seq_group(request_id) @@ -601,8 +657,10 @@ def _process_model_outputs( # Create the outputs. request_outputs: List[RequestOutput] = [] - for seq_group in (scheduled_seq_groups + - scheduler_outputs.ignored_seq_groups): + for seq_group in scheduled_seq_groups: + request_output = RequestOutput.from_seq_group(seq_group) + request_outputs.append(request_output) + for seq_group in scheduler_outputs.ignored_seq_groups: request_output = RequestOutput.from_seq_group(seq_group) request_outputs.append(request_output) @@ -615,11 +673,53 @@ def _process_model_outputs( def step(self) -> List[RequestOutput]: """Performs one decoding iteration and returns newly generated results. - This function performs one decoding iteration of the engine. 
It first - schedules the sequences to be executed in the next iteration and the - token blocks to be swapped in/out/copy. Then, it executes the model - and updates the scheduler with the model outputs. Finally, it decodes - the sequences and returns the newly generated results. + .. figure:: https://i.imgur.com/sv2HssD.png + :alt: Overview of the step function + :align: center + + Overview of the step function. + + Details: + - Step 1: Schedules the sequences to be executed in the next + iteration and the token blocks to be swapped in/out/copy. + + - Depending on the scheduling policy, + sequences may be `preempted/reordered`. + - A Sequence Group (SG) refer to a group of sequences + that are generated from the same prompt. + + - Step 2: Calls the workers to execute the model. + - Step 3: Processes the model output. This mainly includes: + + - Decodes the relevant outputs. + - Updates the scheduled sequence groups with model outputs + based on its `sampling parameters` (`use_beam_search` or not). + - Frees the finished sequence groups. + + - Finally, it creates and returns the newly generated results. + + Example: + >>> # Please see the example/ folder for more detailed examples. 
+ >>> + >>> # initialize engine and request arguments + >>> engine = LLMEngine.from_engine_args(engine_args) + >>> example_inputs = [(0, "What is LLM?", + >>> SamplingParams(temperature=0.0))] + >>> + >>> # Start the engine with an event loop + >>> while True: + >>> if example_inputs: + >>> req_id, prompt, sampling_params = example_inputs.pop(0) + >>> engine.add_request(str(req_id), prompt, sampling_params) + >>> + >>> # continue the request processing + >>> request_outputs = engine.step() + >>> for request_output in request_outputs: + >>> if request_output.finished: + >>> # return or show the request output + >>> + >>> if not (engine.has_unfinished_requests() or example_inputs): + >>> break """ seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule() diff --git a/vllm/engine/ray_utils.py b/vllm/engine/ray_utils.py index fb8854e068c8..1cb5fcda344f 100644 --- a/vllm/engine/ray_utils.py +++ b/vllm/engine/ray_utils.py @@ -65,10 +65,9 @@ def initialize_cluster( the default Ray cluster address. Returns: - A tuple of (`distributed_init_method`, `placement_group`). The - `distributed_init_method` is the address for initializing the - distributed backend. `placement_group` includes the specification - of the resources for each distributed worker. + An optional `PlacementGroup`. It includes the specification + of the resources for each distributed worker. None if Ray is + not used. 
""" if parallel_config.worker_use_ray or engine_use_ray: if ray is None: diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py index 73f7890ef7d3..9c27bcf2636c 100644 --- a/vllm/entrypoints/api_server.py +++ b/vllm/entrypoints/api_server.py @@ -74,12 +74,18 @@ async def stream_results() -> AsyncGenerator[bytes, None]: parser.add_argument("--port", type=int, default=8000) parser.add_argument("--ssl-keyfile", type=str, default=None) parser.add_argument("--ssl-certfile", type=str, default=None) + parser.add_argument( + "--root-path", + type=str, + default=None, + help="FastAPI root_path when app is behind a path based routing proxy") parser = AsyncEngineArgs.add_cli_args(parser) args = parser.parse_args() engine_args = AsyncEngineArgs.from_cli_args(args) engine = AsyncLLMEngine.from_engine_args(engine_args) + app.root_path = args.root_path uvicorn.run(app, host=args.host, port=args.port, diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 30c55f4c01c5..d652045c8ad7 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -1,19 +1,12 @@ -# Adapted from -# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/serve/openai_api_server.py - import argparse import asyncio -import codecs import json -import time from contextlib import asynccontextmanager -from http import HTTPStatus -from typing import AsyncGenerator, Dict, List, Optional, Tuple, Union - from aioprometheus import MetricsMiddleware from aioprometheus.asgi.starlette import metrics import fastapi import uvicorn +from http import HTTPStatus from fastapi import Request from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware @@ -22,26 +15,16 @@ from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.metrics import add_global_metrics_labels -from 
vllm.entrypoints.openai.protocol import ( - CompletionRequest, CompletionResponse, CompletionResponseChoice, - CompletionResponseStreamChoice, CompletionStreamResponse, - ChatCompletionRequest, ChatCompletionResponse, - ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, - ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse, - LogProbs, ModelCard, ModelList, ModelPermission, UsageInfo) +from vllm.entrypoints.openai.protocol import CompletionRequest, ChatCompletionRequest, ErrorResponse from vllm.logger import init_logger -from vllm.outputs import RequestOutput -from vllm.sampling_params import SamplingParams -from vllm.transformers_utils.tokenizer import get_tokenizer -from vllm.utils import random_uuid +from vllm.entrypoints.openai.serving_chat import OpenAIServingChat +from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion TIMEOUT_KEEP_ALIVE = 5 # seconds +openai_serving_chat: OpenAIServingChat = None +openai_serving_completion: OpenAIServingCompletion = None logger = init_logger(__name__) -served_model = None -engine_args = None -engine = None -response_role = None @asynccontextmanager @@ -106,6 +89,11 @@ def parse_args(): type=str, default=None, help="The file path to the SSL cert file") + parser.add_argument( + "--root-path", + type=str, + default=None, + help="FastAPI root_path when app is behind a path based routing proxy") parser = AsyncEngineArgs.add_cli_args(parser) return parser.parse_args() @@ -115,72 +103,10 @@ def parse_args(): app.add_route("/metrics", metrics) # Exposes HTTP metrics -def create_error_response(status_code: HTTPStatus, - message: str) -> JSONResponse: - return JSONResponse(ErrorResponse(message=message, - type="invalid_request_error").dict(), - status_code=status_code.value) - - -def load_chat_template(args, tokenizer): - if args.chat_template is not None: - try: - with open(args.chat_template, "r") as f: - chat_template = f.read() - except OSError: - # If opening a file fails, 
set chat template to be args to - # ensure we decode so our escape are interpreted correctly - chat_template = codecs.decode(args.chat_template, "unicode_escape") - - tokenizer.chat_template = chat_template - logger.info( - f"Using supplied chat template:\n{tokenizer.chat_template}") - elif tokenizer.chat_template is not None: - logger.info(f"Using default chat template:\n{tokenizer.chat_template}") - else: - logger.warning("No chat template provided. Chat API will not work.") - - @app.exception_handler(RequestValidationError) async def validation_exception_handler(_, exc): - return create_error_response(HTTPStatus.BAD_REQUEST, str(exc)) - - -async def check_model(request) -> Optional[JSONResponse]: - if request.model == served_model: - return - ret = create_error_response( - HTTPStatus.NOT_FOUND, - f"The model `{request.model}` does not exist.", - ) - return ret - - -async def check_length( - request: Union[ChatCompletionRequest, CompletionRequest], - prompt: Optional[str] = None, - prompt_ids: Optional[List[int]] = None -) -> Tuple[List[int], Optional[JSONResponse]]: - assert (not (prompt is None and prompt_ids is None) - and not (prompt is not None and prompt_ids is not None) - ), "Either prompt or prompt_ids should be provided." - input_ids = prompt_ids if prompt_ids is not None else tokenizer( - prompt).input_ids - token_num = len(input_ids) - - if request.max_tokens is None: - request.max_tokens = max_model_len - token_num - if token_num + request.max_tokens > max_model_len: - return input_ids, create_error_response( - HTTPStatus.BAD_REQUEST, - f"This model's maximum context length is {max_model_len} tokens. " - f"However, you requested {request.max_tokens + token_num} tokens " - f"({token_num} in the messages, " - f"{request.max_tokens} in the completion). 
@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
                                 raw_request: Request):
    """Chat Completion API endpoint.

    Delegates all request handling to `openai_serving_chat` and wraps the
    result: a stream is returned as Server-Sent Events, anything else
    (including an ErrorResponse) as a JSON body.
    """
    generator = await openai_serving_chat.create_chat_completion(
        request, raw_request)
    is_error = isinstance(generator, ErrorResponse)
    if request.stream and not is_error:
        # Streamed tokens go out as SSE chunks.
        return StreamingResponse(content=generator,
                                 media_type="text/event-stream")
    return JSONResponse(content=generator.dict())
- - NOTE: Currently we do not support the following features: - - suffix (the language models we currently support do not support - suffix) - - logit_bias (to be supported by vLLM engine) - """ - - error_check_ret = await check_model(request) - if error_check_ret is not None: - return error_check_ret - - # OpenAI API supports echoing the prompt when max_tokens is 0. - echo_without_generation = request.echo and request.max_tokens == 0 - - if request.suffix is not None: - # The language models we currently support do not support suffix. - return create_error_response(HTTPStatus.BAD_REQUEST, - "suffix is not currently supported") - - if request.logit_bias is not None and len(request.logit_bias) > 0: - # TODO: support logit_bias in vLLM engine. - return create_error_response(HTTPStatus.BAD_REQUEST, - "logit_bias is not currently supported") - - model_name = request.model - request_id = f"cmpl-{random_uuid()}" - - use_token_ids = False - if isinstance(request.prompt, list): - if len(request.prompt) == 0: - return create_error_response(HTTPStatus.BAD_REQUEST, - "please provide at least one prompt") - first_element = request.prompt[0] - if isinstance(first_element, int): - use_token_ids = True - prompt = request.prompt - elif isinstance(first_element, (str, list)): - # TODO: handles multiple prompt case in list[list[int]] - if len(request.prompt) > 1: - return create_error_response( - HTTPStatus.BAD_REQUEST, - "multiple prompts in a batch is not currently supported") - use_token_ids = not isinstance(first_element, str) - prompt = request.prompt[0] - else: - prompt = request.prompt - - if use_token_ids: - _, error_check_ret = await check_length(request, prompt_ids=prompt) - else: - token_ids, error_check_ret = await check_length(request, prompt=prompt) - if error_check_ret is not None: - return error_check_ret - - created_time = int(time.monotonic()) - try: - spaces_between_special_tokens = request.spaces_between_special_tokens - sampling_params = SamplingParams( - 
n=request.n, - best_of=request.best_of, - presence_penalty=request.presence_penalty, - frequency_penalty=request.frequency_penalty, - repetition_penalty=request.repetition_penalty, - temperature=request.temperature, - top_p=request.top_p, - top_k=request.top_k, - min_p=request.min_p, - stop=request.stop, - stop_token_ids=request.stop_token_ids, - ignore_eos=request.ignore_eos, - max_tokens=request.max_tokens - if not echo_without_generation else 1, - logprobs=request.logprobs, - use_beam_search=request.use_beam_search, - prompt_logprobs=request.logprobs if request.echo else None, - skip_special_tokens=request.skip_special_tokens, - spaces_between_special_tokens=spaces_between_special_tokens, - ) - except ValueError as e: - return create_error_response(HTTPStatus.BAD_REQUEST, str(e)) - - if use_token_ids: - result_generator = engine.generate(None, - sampling_params, - request_id, - prompt_token_ids=prompt) - else: - result_generator = engine.generate(prompt, sampling_params, request_id, - token_ids) - - # Similar to the OpenAI API, when n != best_of, we do not stream the - # results. In addition, we do not stream the results when use beam search. 
- stream = (request.stream - and (request.best_of is None or request.n == request.best_of) - and not request.use_beam_search) - - def create_stream_response_json( - index: int, - text: str, - logprobs: Optional[LogProbs] = None, - finish_reason: Optional[str] = None, - usage: Optional[UsageInfo] = None, - ) -> str: - choice_data = CompletionResponseStreamChoice( - index=index, - text=text, - logprobs=logprobs, - finish_reason=finish_reason, - ) - response = CompletionStreamResponse( - id=request_id, - created=created_time, - model=model_name, - choices=[choice_data], - ) - if usage is not None: - response.usage = usage - response_json = response.json(exclude_unset=True, ensure_ascii=False) - - return response_json - - async def completion_stream_generator() -> AsyncGenerator[str, None]: - previous_texts = [""] * request.n - previous_num_tokens = [0] * request.n - has_echoed = [False] * request.n - async for res in result_generator: - res: RequestOutput - for output in res.outputs: - i = output.index - delta_text = output.text[len(previous_texts[i]):] - token_ids = output.token_ids[previous_num_tokens[i]:] - if request.logprobs is not None: - top_logprobs = output.logprobs[previous_num_tokens[i]:] - else: - top_logprobs = None - offsets = len(previous_texts[i]) - if request.echo and not has_echoed[i]: - if not echo_without_generation: - delta_text = res.prompt + delta_text - token_ids = res.prompt_token_ids + token_ids - if top_logprobs: - top_logprobs = res.prompt_logprobs + top_logprobs - else: # only just return the prompt - delta_text = res.prompt - token_ids = res.prompt_token_ids - if top_logprobs: - top_logprobs = res.prompt_logprobs - has_echoed[i] = True - if request.logprobs is not None: - logprobs = create_logprobs( - token_ids=token_ids, - top_logprobs=top_logprobs, - num_output_top_logprobs=request.logprobs, - initial_text_offset=offsets, - ) - else: - logprobs = None - previous_texts[i] = output.text - previous_num_tokens[i] = len(output.token_ids) - 
finish_reason = output.finish_reason - response_json = create_stream_response_json( - index=i, - text=delta_text, - logprobs=logprobs, - finish_reason=finish_reason, - ) - yield f"data: {response_json}\n\n" - if output.finish_reason is not None: - logprobs = (LogProbs() - if request.logprobs is not None else None) - prompt_tokens = len(res.prompt_token_ids) - completion_tokens = len(output.token_ids) - final_usage = UsageInfo( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=prompt_tokens + completion_tokens, - ) - response_json = create_stream_response_json( - index=i, - text="", - logprobs=logprobs, - finish_reason=output.finish_reason, - usage=final_usage, - ) - yield f"data: {response_json}\n\n" - yield "data: [DONE]\n\n" - - # Streaming response - if stream: - return StreamingResponse(completion_stream_generator(), + generator = await openai_serving_completion.create_completion( + request, raw_request) + if request.stream and not isinstance(generator, ErrorResponse): + return StreamingResponse(content=generator, media_type="text/event-stream") - - # Non-streaming response - final_res: RequestOutput = None - async for res in result_generator: - if await raw_request.is_disconnected(): - # Abort the request if the client disconnects. 
- await engine.abort(request_id) - return create_error_response(HTTPStatus.BAD_REQUEST, - "Client disconnected") - final_res = res - assert final_res is not None - choices = [] - prompt_token_ids = final_res.prompt_token_ids - prompt_logprobs = final_res.prompt_logprobs - prompt_text = final_res.prompt - for output in final_res.outputs: - if request.logprobs is not None: - if not echo_without_generation: - token_ids = output.token_ids - top_logprobs = output.logprobs - if request.echo: - token_ids = prompt_token_ids + token_ids - top_logprobs = prompt_logprobs + top_logprobs - else: - token_ids = prompt_token_ids - top_logprobs = prompt_logprobs - logprobs = create_logprobs( - token_ids=token_ids, - top_logprobs=top_logprobs, - num_output_top_logprobs=request.logprobs, - ) - else: - logprobs = None - if not echo_without_generation: - output_text = output.text - if request.echo: - output_text = prompt_text + output_text - else: - output_text = prompt_text - choice_data = CompletionResponseChoice( - index=output.index, - text=output_text, - logprobs=logprobs, - finish_reason=output.finish_reason, - ) - choices.append(choice_data) - - num_prompt_tokens = len(final_res.prompt_token_ids) - num_generated_tokens = sum( - len(output.token_ids) for output in final_res.outputs) - usage = UsageInfo( - prompt_tokens=num_prompt_tokens, - completion_tokens=num_generated_tokens, - total_tokens=num_prompt_tokens + num_generated_tokens, - ) - response = CompletionResponse( - id=request_id, - created=created_time, - model=model_name, - choices=choices, - usage=usage, - ) - - if request.stream: - # When user requests streaming but we don't stream, we still need to - # return a streaming response with a single event. 
- response_json = response.json(ensure_ascii=False) - - async def fake_stream_generator() -> AsyncGenerator[str, None]: - yield f"data: {response_json}\n\n" - yield "data: [DONE]\n\n" - - return StreamingResponse(fake_stream_generator(), - media_type="text/event-stream") - - return response + else: + return JSONResponse(content=generator.dict()) if __name__ == "__main__": @@ -749,23 +162,17 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]: else: served_model = args.model - response_role = args.response_role - engine_args = AsyncEngineArgs.from_cli_args(args) engine = AsyncLLMEngine.from_engine_args(engine_args) - engine_model_config = asyncio.run(engine.get_model_config()) - max_model_len = engine_model_config.max_model_len - - # A separate tokenizer to map token IDs to strings. - tokenizer = get_tokenizer( - engine_model_config.tokenizer, - tokenizer_mode=engine_model_config.tokenizer_mode, - trust_remote_code=engine_model_config.trust_remote_code) - load_chat_template(args, tokenizer) + openai_serving_chat = OpenAIServingChat(engine, served_model, + args.response_role, + args.chat_template) + openai_serving_completion = OpenAIServingCompletion(engine, served_model) # Register labels for metrics add_global_metrics_labels(model_name=engine_args.model) + app.root_path = args.root_path uvicorn.run(app, host=args.host, port=args.port, diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py new file mode 100644 index 000000000000..9b843a94de10 --- /dev/null +++ b/vllm/entrypoints/openai/serving_chat.py @@ -0,0 +1,288 @@ +import time +import codecs +from fastapi import Request +from typing import AsyncGenerator, AsyncIterator, Union +from vllm.logger import init_logger +from vllm.utils import random_uuid +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, ChatCompletionResponse, + ChatCompletionResponseChoice, 
ChatCompletionResponseStreamChoice, + ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse, + UsageInfo) +from vllm.outputs import RequestOutput +from vllm.sampling_params import SamplingParams +from vllm.entrypoints.openai.serving_engine import OpenAIServing + +logger = init_logger(__name__) + + +class OpenAIServingChat(OpenAIServing): + + def __init__(self, + engine: AsyncLLMEngine, + served_model: str, + response_role: str, + chat_template=None): + super().__init__(engine=engine, served_model=served_model) + self.response_role = response_role + self._load_chat_template(chat_template) + + async def create_chat_completion( + self, request: ChatCompletionRequest, raw_request: Request + ) -> Union[ErrorResponse, AsyncGenerator[str, None], + ChatCompletionResponse]: + """Completion API similar to OpenAI's API. + + See https://platform.openai.com/docs/api-reference/chat/create + for the API specification. This API mimics the OpenAI ChatCompletion API. + + NOTE: Currently we do not support the following features: + - function_call (Users should implement this by themselves) + - logit_bias (to be supported by vLLM engine) + """ + error_check_ret = await self._check_model(request) + if error_check_ret is not None: + return error_check_ret + + if request.logit_bias is not None and len(request.logit_bias) > 0: + # TODO: support logit_bias in vLLM engine. 
+ return self.create_error_response( + "logit_bias is not currently supported") + + try: + prompt = self.tokenizer.apply_chat_template( + conversation=request.messages, + tokenize=False, + add_generation_prompt=request.add_generation_prompt) + except Exception as e: + logger.error( + f"Error in applying chat template from request: {str(e)}") + return self.create_error_response(str(e)) + + token_ids, error_check_ret = await self._check_length(request, + prompt=prompt) + if error_check_ret is not None: + return error_check_ret + + request_id = f"cmpl-{random_uuid()}" + try: + spaces_between_special_tokens = request.spaces_between_special_tokens + sampling_params = SamplingParams( + n=request.n, + presence_penalty=request.presence_penalty, + frequency_penalty=request.frequency_penalty, + repetition_penalty=request.repetition_penalty, + temperature=request.temperature, + top_p=request.top_p, + min_p=request.min_p, + stop=request.stop, + stop_token_ids=request.stop_token_ids, + max_tokens=request.max_tokens, + best_of=request.best_of, + top_k=request.top_k, + ignore_eos=request.ignore_eos, + use_beam_search=request.use_beam_search, + skip_special_tokens=request.skip_special_tokens, + spaces_between_special_tokens=spaces_between_special_tokens, + ) + except ValueError as e: + return self.create_error_response(str(e)) + + result_generator = self.engine.generate(prompt, sampling_params, + request_id, token_ids) + # Streaming response + if request.stream: + return self.chat_completion_stream_generator( + request, result_generator, request_id) + else: + return await self.chat_completion_full_generator( + request, raw_request, result_generator, request_id) + + def get_chat_request_role(self, request: ChatCompletionRequest) -> str: + if request.add_generation_prompt: + return self.response_role + else: + return request.messages[-1].role + + async def chat_completion_stream_generator( + self, request: ChatCompletionRequest, + result_generator: AsyncIterator[RequestOutput], 
request_id: str + ) -> Union[ErrorResponse, AsyncGenerator[str, None]]: + + model_name = request.model + created_time = int(time.monotonic()) + chunk_object_type = "chat.completion.chunk" + + # Send first response for each request.n (index) with the role + role = self.get_chat_request_role(request) + for i in range(request.n): + choice_data = ChatCompletionResponseStreamChoice( + index=i, delta=DeltaMessage(role=role), finish_reason=None) + chunk = ChatCompletionStreamResponse(id=request_id, + object=chunk_object_type, + created=created_time, + choices=[choice_data], + model=model_name) + data = chunk.json(exclude_unset=True, ensure_ascii=False) + yield f"data: {data}\n\n" + + # Send response to echo the input portion of the last message + if request.echo: + last_msg_content = "" + if request.messages and isinstance( + request.messages, list) and request.messages[-1].get( + "content") and request.messages[-1].get( + "role") == role: + last_msg_content = request.messages[-1]["content"] + if last_msg_content: + for i in range(request.n): + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=DeltaMessage(content=last_msg_content), + finish_reason=None) + chunk = ChatCompletionStreamResponse( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[choice_data], + model=model_name) + data = chunk.json(exclude_unset=True, ensure_ascii=False) + yield f"data: {data}\n\n" + + # Send response for each token for each request.n (index) + previous_texts = [""] * request.n + previous_num_tokens = [0] * request.n + finish_reason_sent = [False] * request.n + async for res in result_generator: + res: RequestOutput + for output in res.outputs: + i = output.index + + if finish_reason_sent[i]: + continue + + delta_text = output.text[len(previous_texts[i]):] + previous_texts[i] = output.text + previous_num_tokens[i] = len(output.token_ids) + + if output.finish_reason is None: + # Send token-by-token response for each request.n + choice_data = 
ChatCompletionResponseStreamChoice( + index=i, + delta=DeltaMessage(content=delta_text), + finish_reason=None) + chunk = ChatCompletionStreamResponse( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[choice_data], + model=model_name) + data = chunk.json(exclude_unset=True, ensure_ascii=False) + yield f"data: {data}\n\n" + else: + # Send the finish response for each request.n only once + prompt_tokens = len(res.prompt_token_ids) + final_usage = UsageInfo( + prompt_tokens=prompt_tokens, + completion_tokens=previous_num_tokens[i], + total_tokens=prompt_tokens + previous_num_tokens[i], + ) + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=DeltaMessage(content=delta_text), + finish_reason=output.finish_reason) + chunk = ChatCompletionStreamResponse( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[choice_data], + model=model_name) + if final_usage is not None: + chunk.usage = final_usage + data = chunk.json(exclude_unset=True, + exclude_none=True, + ensure_ascii=False) + yield f"data: {data}\n\n" + finish_reason_sent[i] = True + # Send the final done message after all response.n are finished + yield "data: [DONE]\n\n" + + async def chat_completion_full_generator( + self, request: ChatCompletionRequest, raw_request: Request, + result_generator: AsyncIterator[RequestOutput], + request_id: str) -> Union[ErrorResponse, ChatCompletionResponse]: + + model_name = request.model + created_time = int(time.monotonic()) + final_res: RequestOutput = None + + async for res in result_generator: + if await raw_request.is_disconnected(): + # Abort the request if the client disconnects. 
+ await self.engine.abort(request_id) + return self.create_error_response("Client disconnected") + final_res = res + assert final_res is not None + + choices = [] + role = self.get_chat_request_role(request) + for output in final_res.outputs: + choice_data = ChatCompletionResponseChoice( + index=output.index, + message=ChatMessage(role=role, content=output.text), + finish_reason=output.finish_reason, + ) + choices.append(choice_data) + + if request.echo: + last_msg_content = "" + if request.messages and isinstance( + request.messages, list) and request.messages[-1].get( + "content") and request.messages[-1].get( + "role") == role: + last_msg_content = request.messages[-1]["content"] + + for choice in choices: + full_message = last_msg_content + choice.message.content + choice.message.content = full_message + + num_prompt_tokens = len(final_res.prompt_token_ids) + num_generated_tokens = sum( + len(output.token_ids) for output in final_res.outputs) + usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=num_generated_tokens, + total_tokens=num_prompt_tokens + num_generated_tokens, + ) + response = ChatCompletionResponse( + id=request_id, + created=created_time, + model=model_name, + choices=choices, + usage=usage, + ) + + return response + + def _load_chat_template(self, chat_template): + if chat_template is not None: + try: + with open(chat_template, "r") as f: + self.tokenizer.chat_template = f.read() + except OSError: + # If opening a file fails, set chat template to be args to + # ensure we decode so our escape are interpreted correctly + self.tokenizer.chat_template = codecs.decode( + chat_template, "unicode_escape") + + logger.info( + f"Using supplied chat template:\n{self.tokenizer.chat_template}" + ) + elif self.tokenizer.chat_template is not None: + logger.info( + f"Using default chat template:\n{self.tokenizer.chat_template}" + ) + else: + logger.warning( + "No chat template provided. 
Chat API will not work.") diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py new file mode 100644 index 000000000000..d842d1a2a919 --- /dev/null +++ b/vllm/entrypoints/openai/serving_completion.py @@ -0,0 +1,295 @@ +import time +from fastapi import Request +from typing import AsyncGenerator, Optional +from vllm.logger import init_logger +from vllm.utils import random_uuid +from vllm.engine.async_llm_engine import AsyncLLMEngine +from .protocol import (CompletionRequest, CompletionResponse, + CompletionResponseChoice, + CompletionResponseStreamChoice, + CompletionStreamResponse, LogProbs, UsageInfo) +from vllm.outputs import RequestOutput +from vllm.sampling_params import SamplingParams +from vllm.entrypoints.openai.serving_engine import OpenAIServing + +logger = init_logger(__name__) + + +class OpenAIServingCompletion(OpenAIServing): + + def __init__(self, engine: AsyncLLMEngine, served_model: str): + super().__init__(engine=engine, served_model=served_model) + + async def create_completion(self, request: CompletionRequest, + raw_request: Request): + """Completion API similar to OpenAI's API. + + See https://platform.openai.com/docs/api-reference/completions/create + for the API specification. This API mimics the OpenAI Completion API. + + NOTE: Currently we do not support the following features: + - suffix (the language models we currently support do not support + suffix) + - logit_bias (to be supported by vLLM engine) + """ + + error_check_ret = await self._check_model(request) + if error_check_ret is not None: + return error_check_ret + + # OpenAI API supports echoing the prompt when max_tokens is 0. + echo_without_generation = request.echo and request.max_tokens == 0 + + if request.suffix is not None: + # The language models we currently support do not support suffix. 
+ return self.create_error_response( + "suffix is not currently supported") + + if request.logit_bias is not None and len(request.logit_bias) > 0: + # TODO: support logit_bias in vLLM engine. + return self.create_error_response( + "logit_bias is not currently supported") + + model_name = request.model + request_id = f"cmpl-{random_uuid()}" + + use_token_ids = False + if isinstance(request.prompt, list): + if len(request.prompt) == 0: + return self.create_error_response( + "please provide at least one prompt") + first_element = request.prompt[0] + if isinstance(first_element, int): + use_token_ids = True + prompt = request.prompt + elif isinstance(first_element, (str, list)): + # TODO: handles multiple prompt case in list[list[int]] + if len(request.prompt) > 1: + return self.create_error_response( + "multiple prompts in a batch is not currently supported" + ) + use_token_ids = not isinstance(first_element, str) + prompt = request.prompt[0] + else: + prompt = request.prompt + + if use_token_ids: + _, error_check_ret = await self._check_length(request, + prompt_ids=prompt) + else: + token_ids, error_check_ret = await self._check_length( + request, prompt=prompt) + if error_check_ret is not None: + return error_check_ret + + created_time = int(time.monotonic()) + try: + spaces_between_special_tokens = request.spaces_between_special_tokens + sampling_params = SamplingParams( + n=request.n, + best_of=request.best_of, + presence_penalty=request.presence_penalty, + frequency_penalty=request.frequency_penalty, + repetition_penalty=request.repetition_penalty, + temperature=request.temperature, + top_p=request.top_p, + top_k=request.top_k, + min_p=request.min_p, + stop=request.stop, + stop_token_ids=request.stop_token_ids, + ignore_eos=request.ignore_eos, + max_tokens=request.max_tokens + if not echo_without_generation else 1, + logprobs=request.logprobs, + use_beam_search=request.use_beam_search, + prompt_logprobs=request.logprobs if request.echo else None, + 
skip_special_tokens=request.skip_special_tokens, + spaces_between_special_tokens=spaces_between_special_tokens, + ) + except ValueError as e: + return self.create_error_response(str(e)) + + if use_token_ids: + result_generator = self.engine.generate(None, + sampling_params, + request_id, + prompt_token_ids=prompt) + else: + result_generator = self.engine.generate(prompt, sampling_params, + request_id, token_ids) + + # Similar to the OpenAI API, when n != best_of, we do not stream the + # results. In addition, we do not stream the results when use beam search. + stream = (request.stream + and (request.best_of is None or request.n == request.best_of) + and not request.use_beam_search) + + def create_stream_response_json( + index: int, + text: str, + logprobs: Optional[LogProbs] = None, + finish_reason: Optional[str] = None, + usage: Optional[UsageInfo] = None, + ) -> str: + choice_data = CompletionResponseStreamChoice( + index=index, + text=text, + logprobs=logprobs, + finish_reason=finish_reason, + ) + response = CompletionStreamResponse( + id=request_id, + created=created_time, + model=model_name, + choices=[choice_data], + ) + if usage is not None: + response.usage = usage + response_json = response.json(exclude_unset=True, + ensure_ascii=False) + + return response_json + + async def completion_stream_generator() -> AsyncGenerator[str, None]: + previous_texts = [""] * request.n + previous_num_tokens = [0] * request.n + has_echoed = [False] * request.n + async for res in result_generator: + res: RequestOutput + for output in res.outputs: + i = output.index + delta_text = output.text[len(previous_texts[i]):] + token_ids = output.token_ids[previous_num_tokens[i]:] + if request.logprobs is not None: + top_logprobs = output.logprobs[previous_num_tokens[i]:] + else: + top_logprobs = None + offsets = len(previous_texts[i]) + if request.echo and not has_echoed[i]: + if not echo_without_generation: + delta_text = res.prompt + delta_text + token_ids = res.prompt_token_ids + 
token_ids + if top_logprobs: + top_logprobs = res.prompt_logprobs + top_logprobs + else: # only just return the prompt + delta_text = res.prompt + token_ids = res.prompt_token_ids + if top_logprobs: + top_logprobs = res.prompt_logprobs + has_echoed[i] = True + if request.logprobs is not None: + logprobs = self._create_logprobs( + token_ids=token_ids, + top_logprobs=top_logprobs, + num_output_top_logprobs=request.logprobs, + initial_text_offset=offsets, + ) + else: + logprobs = None + previous_texts[i] = output.text + previous_num_tokens[i] = len(output.token_ids) + finish_reason = output.finish_reason + response_json = create_stream_response_json( + index=i, + text=delta_text, + logprobs=logprobs, + finish_reason=finish_reason, + ) + yield f"data: {response_json}\n\n" + if output.finish_reason is not None: + logprobs = (LogProbs() + if request.logprobs is not None else None) + prompt_tokens = len(res.prompt_token_ids) + completion_tokens = len(output.token_ids) + final_usage = UsageInfo( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ) + response_json = create_stream_response_json( + index=i, + text="", + logprobs=logprobs, + finish_reason=output.finish_reason, + usage=final_usage, + ) + yield f"data: {response_json}\n\n" + yield "data: [DONE]\n\n" + + # Streaming response + if stream: + return completion_stream_generator() + + # Non-streaming response + final_res: RequestOutput = None + async for res in result_generator: + if await raw_request.is_disconnected(): + # Abort the request if the client disconnects. 
+ await self.engine.abort(request_id) + return self.create_error_response("Client disconnected") + final_res = res + assert final_res is not None + choices = [] + prompt_token_ids = final_res.prompt_token_ids + prompt_logprobs = final_res.prompt_logprobs + prompt_text = final_res.prompt + for output in final_res.outputs: + if request.logprobs is not None: + if not echo_without_generation: + token_ids = output.token_ids + top_logprobs = output.logprobs + if request.echo: + token_ids = prompt_token_ids + token_ids + top_logprobs = prompt_logprobs + top_logprobs + else: + token_ids = prompt_token_ids + top_logprobs = prompt_logprobs + logprobs = self._create_logprobs( + token_ids=token_ids, + top_logprobs=top_logprobs, + num_output_top_logprobs=request.logprobs, + ) + else: + logprobs = None + if not echo_without_generation: + output_text = output.text + if request.echo: + output_text = prompt_text + output_text + else: + output_text = prompt_text + choice_data = CompletionResponseChoice( + index=output.index, + text=output_text, + logprobs=logprobs, + finish_reason=output.finish_reason, + ) + choices.append(choice_data) + + num_prompt_tokens = len(final_res.prompt_token_ids) + num_generated_tokens = sum( + len(output.token_ids) for output in final_res.outputs) + usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=num_generated_tokens, + total_tokens=num_prompt_tokens + num_generated_tokens, + ) + response = CompletionResponse( + id=request_id, + created=created_time, + model=model_name, + choices=choices, + usage=usage, + ) + + if request.stream: + # When user requests streaming but we don't stream, we still need to + # return a streaming response with a single event. 
+ response_json = response.json(ensure_ascii=False) + + async def fake_stream_generator() -> AsyncGenerator[str, None]: + yield f"data: {response_json}\n\n" + yield "data: [DONE]\n\n" + + return fake_stream_generator() + + return response diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py new file mode 100644 index 000000000000..e77a0720e498 --- /dev/null +++ b/vllm/entrypoints/openai/serving_engine.py @@ -0,0 +1,130 @@ +import asyncio +from http import HTTPStatus +from typing import Dict, List, Optional, Tuple, Union +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import get_tokenizer +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.entrypoints.openai.protocol import (CompletionRequest, + ChatCompletionRequest, + ErrorResponse, LogProbs, + ModelCard, ModelList, + ModelPermission) + +logger = init_logger(__name__) + + +class OpenAIServing: + + def __init__(self, engine: AsyncLLMEngine, served_model: str): + self.engine = engine + self.served_model = served_model + + self.max_model_len = 0 + self.tokenizer = None + + try: + event_loop = asyncio.get_running_loop() + except RuntimeError: + event_loop = None + + if event_loop is not None and event_loop.is_running( + ): # If the current is instanced by Ray Serve, there is already a running event loop + event_loop.create_task(self._post_init()) + else: # When using single vLLM without engine_use_ray + asyncio.run(self._post_init()) + + async def _post_init(self): + engine_model_config = await self.engine.get_model_config() + self.max_model_len = engine_model_config.max_model_len + + # A separate tokenizer to map token IDs to strings. + self.tokenizer = get_tokenizer( + engine_model_config.tokenizer, + tokenizer_mode=engine_model_config.tokenizer_mode, + trust_remote_code=engine_model_config.trust_remote_code) + + async def show_available_models(self) -> ModelList: + """Show available models. 
Right now we only have one model.""" + model_cards = [ + ModelCard(id=self.served_model, + root=self.served_model, + permission=[ModelPermission()]) + ] + return ModelList(data=model_cards) + + def _create_logprobs( + self, + token_ids: List[int], + top_logprobs: Optional[List[Optional[Dict[int, float]]]] = None, + num_output_top_logprobs: Optional[int] = None, + initial_text_offset: int = 0, + ) -> LogProbs: + """Create OpenAI-style logprobs.""" + logprobs = LogProbs() + last_token_len = 0 + if num_output_top_logprobs: + logprobs.top_logprobs = [] + for i, token_id in enumerate(token_ids): + step_top_logprobs = top_logprobs[i] + if step_top_logprobs is not None: + token_logprob = step_top_logprobs[token_id] + else: + token_logprob = None + token = self.tokenizer.convert_ids_to_tokens(token_id) + logprobs.tokens.append(token) + logprobs.token_logprobs.append(token_logprob) + if len(logprobs.text_offset) == 0: + logprobs.text_offset.append(initial_text_offset) + else: + logprobs.text_offset.append(logprobs.text_offset[-1] + + last_token_len) + last_token_len = len(token) + + if num_output_top_logprobs: + logprobs.top_logprobs.append({ + self.tokenizer.convert_ids_to_tokens(i): p + for i, p in step_top_logprobs.items() + } if step_top_logprobs else None) + return logprobs + + def create_error_response( + self, + message: str, + err_type: str = "BadRequestError", + status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse: + return ErrorResponse(message=message, + type=err_type, + code=status_code.value) + + async def _check_model(self, request) -> Optional[ErrorResponse]: + if request.model == self.served_model: + return + return self.create_error_response( + message=f"The model `{request.model}` does not exist.", + err_type="NotFoundError", + status_code=HTTPStatus.NOT_FOUND) + + async def _check_length( + self, + request: Union[ChatCompletionRequest, CompletionRequest], + prompt: Optional[str] = None, + prompt_ids: Optional[List[int]] = None + ) -> 
Tuple[List[int], Optional[ErrorResponse]]: + assert (not (prompt is None and prompt_ids is None) + and not (prompt is not None and prompt_ids is not None) + ), "Either prompt or prompt_ids should be provided." + input_ids = prompt_ids if prompt_ids is not None else self.tokenizer( + prompt).input_ids + token_num = len(input_ids) + + if request.max_tokens is None: + request.max_tokens = self.max_model_len - token_num + if token_num + request.max_tokens > self.max_model_len: + return input_ids, self.create_error_response( + f"This model's maximum context length is {self.max_model_len} tokens. " + f"However, you requested {request.max_tokens + token_num} tokens " + f"({token_num} in the messages, " + f"{request.max_tokens} in the completion). " + f"Please reduce the length of the messages or completion.", ) + else: + return input_ids, None diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index 6482875d1c55..f1008ec8159f 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -156,20 +156,15 @@ def forward( output = out.view_as(query) else: # Decoding run. - if key_cache is not None and value_cache is not None: - output = _paged_attention( - query, - key_cache, - value_cache, - input_metadata, - self.num_kv_heads, - self.scale, - self.alibi_slopes, - ) - else: - # This happens during the initial memory profiling run for - # CUDA graphs. - output = torch.zeros_like(query) + output = _paged_attention( + query, + key_cache, + value_cache, + input_metadata, + self.num_kv_heads, + self.scale, + self.alibi_slopes, + ) # Reshape the output tensor. 
return output.view(batch_size, seq_len, hidden_size) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 5190de65d795..5e1d63a6a62e 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -423,7 +423,10 @@ def weight_loader(self, shard_offset = shard_offset // param.pack_factor param_data = param_data.narrow(output_dim, shard_offset, shard_size) - shard_id = tp_rank // self.num_kv_head_replicas + if loaded_shard_id == "q": + shard_id = tp_rank + else: + shard_id = tp_rank // self.num_kv_head_replicas start_idx = shard_id * shard_size loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) diff --git a/vllm/model_executor/layers/rejection_sampler.py b/vllm/model_executor/layers/rejection_sampler.py new file mode 100644 index 000000000000..3e1cfc783b8e --- /dev/null +++ b/vllm/model_executor/layers/rejection_sampler.py @@ -0,0 +1,392 @@ +from typing import Tuple, Optional +from functools import cached_property + +import torch +import torch.nn as nn +import torch.jit + + +class RejectionSampler(nn.Module): + """Apply modified rejection sampling as described in "Accelerating Large + Language Model Decoding with Speculative Sampling" + https://arxiv.org/pdf/2302.01318.pdf. + """ + + def __init__(self, strict_mode: bool = False): + """Create a rejection sampler. + + Args: + strict_mode: Whether or not to perform shape/device/dtype checks + during sampling. This catches correctness issues but adds + nontrivial latency. + """ + super().__init__() + self.probs_dtype = torch.float32 + self.token_id_dtype = torch.int64 + self._strict_mode = strict_mode + + # NOTE: A "bonus token" is accepted iff all proposal tokens are + # accepted. There is always only one possible bonus token. We store this + # value in a variable for readability. 
+ self._num_bonus_tokens = 1 + + self.num_accepted_tokens: Optional[torch.Tensor] = None + self.num_emitted_tokens: Optional[torch.Tensor] = None + self.num_draft_tokens: int = 0 + + def init_gpu_tensors(self, rank: int) -> None: + assert self.num_accepted_tokens is None + device = f"cuda:{rank}" + self.num_accepted_tokens = torch.tensor(0, + dtype=torch.long, + device=device) + self.num_emitted_tokens = torch.tensor(0, + dtype=torch.long, + device=device) + + def forward( + self, + target_probs: torch.Tensor, + bonus_token_ids: torch.Tensor, + draft_probs: torch.Tensor, + draft_token_ids: torch.Tensor, + ) -> torch.Tensor: + """Sample token ids using rejection sampling. This accepts or rejects + tokens proposed by the draft model using the probability of each token + according to the draft and target models. + + In the worst case where all draft tokens are rejected, it is guaranteed + one correct token will be emitted. + + In the case where all draft tokens are accepted, a bonus token will be + accepted as its cheap to have the target model score this speculative + sequence. + + Args: + target_probs: The probability distribution over token ids given + context according to the target model. + shape = [batch_size, num_speculative_tokens, vocab_size] + + bonus_token_ids: The "bonus" token ids that are accepted iff all + speculative tokens in a sequence are accepted. + shape = [batch_size, num_bonus_tokens] + + draft_probs: The probability distribution over token ids given + context according to the draft model. + shape = [batch_size, num_speculative_tokens, vocab_size] + + draft_token_ids: The token ids that were sampled from the draft + probabilities. + shape = [batch_size, num_speculative_tokens] + + Returns: + output_token_ids: The token ids sampled via rejection sampling, + or -1 if unable to sample a token because the previous token + was rejected. 
+ shape = [batch_size, num_speculative_tokens + num_bonus_tokens] + """ + # Only perform shape/dtype/device checking in strict mode, as it adds + # overhead. + if self._strict_mode: + self._raise_if_incorrect_shape(target_probs, bonus_token_ids, + draft_probs, draft_token_ids) + self._raise_if_incorrect_dtype(target_probs, bonus_token_ids, + draft_probs, draft_token_ids) + self._raise_if_inconsistent_device(target_probs, bonus_token_ids, + draft_probs, draft_token_ids) + self._raise_if_out_of_bounds_vocab(target_probs.shape[-1], + bonus_token_ids, + draft_token_ids) + + accepted, recovered_token_ids = self._batch_modified_rejection_sampling( + target_probs, + draft_probs, + draft_token_ids, + ) + + output_token_ids = self._create_output( + accepted, + recovered_token_ids, + draft_token_ids, + bonus_token_ids, + ) + return output_token_ids + + def _batch_modified_rejection_sampling( + self, + target_probs: torch.Tensor, # [batch_size, k, vocab_size] + draft_probs: torch.Tensor, # [batch_size, k, vocab_size] + draft_token_ids: torch.Tensor, # [batch_size, k] + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Perform modified rejection sampling on each sequence. + + Returns: + A tuple of two tensors: + 0: A bool tensor of which tokens in each sequence is accepted. + shape = [batch_size, k] + 1: Token ids sampled from a recovered distribution, to be used + when a token is rejected. 
+ shape = [batch_size, k] + """ + + batch_size, k, vocab_size = draft_probs.shape + + # shape [batch_size, k] + accepted = self._get_accepted(target_probs, draft_probs, + draft_token_ids) + + recovered_probs = self._get_recovered_probs( + target_probs, draft_probs).reshape(batch_size * k, vocab_size) + + recovered_token_ids = _multinomial(recovered_probs, + num_samples=1).reshape( + batch_size, k) + return accepted, recovered_token_ids + + def _get_accepted( + self, + target_probs: torch.Tensor, # [batch_size, k, vocab_size] + draft_probs: torch.Tensor, # [batch_size, k, vocab_size] + draft_token_ids: torch.Tensor, # [batch_size, k] + ) -> torch.Tensor: + r"""Create bool matrix over the proposed draft tokens. If + True, then a token can be accepted, else it should be + rejected. + + Given :math:`q(\hat{x}_{n+1}|x_1, \dots, x_n)`, the probability of + :math:`\hat{x}_{n+1}` given context :math:`x_1, \dots, x_n` according + to the target model, and :math:`p(\hat{x}_{n+1}|x_1, \dots, x_n)`, the + same conditional probability according to the draft model, the token + is accepted with probability: + + .. math:: + \min\left(1, \frac{q(\hat{x}_{n+1}|x_1, \dots, x_n)} + {p(\hat{x}_{n+1}|x_1, \dots, x_n)}\right) + + This implementation does not apply causality. When using the output, + if a token is rejected, subsequent tokens should not be used. + + Returns a bool tensor of shape [batch_size, k] specifying which tokens + are accepted. 
+ """ + batch_size, k, _ = draft_probs.shape + batch_indices = torch.arange(batch_size, + device=target_probs.device)[:, None] + probs_indicies = torch.arange(k, device=target_probs.device) + + # shape [batch_size, k] + selected_draft_probs = draft_probs[batch_indices, probs_indicies, + draft_token_ids] + + # shape [batch_size, k] + selected_target_probs = target_probs[batch_indices, probs_indicies, + draft_token_ids] + + uniform_rand = torch.rand(batch_size, + k, + dtype=self.probs_dtype, + device=target_probs.device) + capped_ratio = torch.minimum( + selected_target_probs / selected_draft_probs, + torch.full((1, ), 1, device=target_probs.device)) + accepted = uniform_rand < capped_ratio + + return accepted + + def _get_recovered_probs( + self, + target_probs: torch.Tensor, # [k, vocab_size] + draft_probs: torch.Tensor, # [k, vocab_size] + ) -> torch.Tensor: + r"""Create a probability distribution for each proposed token which can + be sampled if the proposed token is rejected. + + When this routine is applied sequentially, the true distribution of the + target model is recovered (within hardware numerics). + + The probability distribution used in this rejection case is constructed + as follows. Given :math:`q(x|x_1, \dots, x_n)`, the probability of + :math:`x` given context :math:`x_1, \dots, x_n` according to the target + model and :math:`p(x|x_1, \dots, x_n)`, the same conditional probability + according to the draft model: + + .. math:: + x_{n+1} \sim (q(x|x_1, \dots, x_n) - p(x|x_1, \dots, x_n))_+ + + where :math:`(f(x))_+` is defined as: + + .. math:: + (f(x))_+ = \frac{\max(0, f(x))}{\sum_x \max(0, f(x))} + + See https://github.com/vllm-project/vllm/pull/2336 for a visualization + of the draft, target, and recovered probability distributions. + + Returns a tensor of shape [batch_size, k, vocab_size]. + + Note: This batches operations on GPU and thus constructs the recovered + distribution for all tokens, even if they are accepted. 
This causes + division-by-zero errors, so we use self._smallest_positive_value to + avoid that. This introduces some drift to the distribution. + """ + _, k, _ = draft_probs.shape + + # shape [batch_size, k, vocab_size] + difference = target_probs - draft_probs + + # TODO(cade): Can we use logprobs instead of probs, and avoid the + # division-by-zero errors without introducing distribution drift? + + # shape [batch_size, k, vocab_size] + f = torch.clamp(difference, min=self._smallest_positive_value) + + # shape [batch_size, k, vocab_size] + recovered_probs = f / torch.sum(f, dim=-1).reshape(-1, k, 1) + + return recovered_probs + + @cached_property + def _smallest_positive_value(self) -> float: + """Return the smallest positive value representable by the probs dtype. + This value is used when constructing a distribution from which to sample + recovered tokens in the first rejection case. + + See _get_recovered_probs for more details + + Note that this isn't actually the smallest positive value representable + by float32, but the smallest positive normal value. + See https://en.wikipedia.org/wiki/Subnormal_number for more information. + """ + return torch.finfo(self.probs_dtype).tiny + + def _create_output( + self, + accepted: torch.Tensor, # [batch_size, k] + recovered_token_ids: torch.Tensor, # [batch_size, k] + draft_token_ids: torch.Tensor, # [batch_size, k] + bonus_token_ids: torch.Tensor, # [batch_size] + ) -> torch.Tensor: + """Format output. Returns a matrix of token ids. When + a token is rejected via rejection sampling, all subsequent + token ids are set to -1 for the sequence. + + shape = [batch_size, k + num_bonus_tokens] + """ + bonus_token_ids = bonus_token_ids.squeeze() + batch_size, k = recovered_token_ids.shape + + # Determine the index of the first False value for each row. + limits = (accepted == 0).max(1).indices + limits[~(accepted == 0).any(1)] = k + + # Create masks using the indices. 
+ indices = torch.arange(k, device=accepted.device).unsqueeze(0) + accepted_mask = indices < limits.unsqueeze(1) + after_false_mask = indices == limits.unsqueeze(1) + + # Create an extended output tensor + output_with_bonus_tokens = -torch.ones( + (batch_size, k + self._num_bonus_tokens), + dtype=self.token_id_dtype, + device=accepted.device) + output = output_with_bonus_tokens[:, :k] + + # Fill in the first k columns of the output tensor using masks and data + # tensors. + output[:, :k] = torch.where(accepted_mask, draft_token_ids, + -torch.ones_like(draft_token_ids)) + + # Fill the last column. + # We check output directly as accepted may have True values inconsistent + # with causal acceptance. + output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1, + bonus_token_ids, -1) + + # Fill the recovered token ids. + output.mul_(~after_false_mask).add_( + recovered_token_ids.mul(after_false_mask)) + + self.num_accepted_tokens += accepted.sum() + self.num_emitted_tokens += (output_with_bonus_tokens != -1).sum() + self.num_draft_tokens += batch_size * k + + return output_with_bonus_tokens + + def _raise_if_incorrect_shape( + self, + target_probs: torch.Tensor, + bonus_token_ids: torch.Tensor, + draft_probs: torch.Tensor, + draft_token_ids: torch.Tensor, + ) -> None: + (target_batch_size, num_target_probs, + target_vocab_size) = target_probs.shape + bonus_batch_size, num_bonus_tokens = bonus_token_ids.shape + draft_batch_size, num_draft_probs, draft_vocab_size = draft_probs.shape + draft_token_ids_batch_size, num_draft_token_ids = draft_token_ids.shape + + assert draft_batch_size == target_batch_size + assert num_draft_probs == num_target_probs + assert (draft_vocab_size == target_vocab_size + ), f"{draft_vocab_size=} {target_vocab_size=}" + + assert draft_token_ids_batch_size == draft_batch_size + assert num_draft_token_ids == num_draft_probs + + assert bonus_batch_size == target_batch_size + assert num_bonus_tokens == self._num_bonus_tokens + + def 
_raise_if_incorrect_dtype( + self, + target_probs: torch.Tensor, + bonus_token_ids: torch.Tensor, + draft_probs: torch.Tensor, + draft_token_ids: torch.Tensor, + ) -> None: + assert all(probs.dtype == self.probs_dtype + for probs in [target_probs, draft_probs]) + assert all(token_ids.dtype == self.token_id_dtype + for token_ids in [bonus_token_ids, draft_token_ids]) + + def _raise_if_inconsistent_device( + self, + target_probs: torch.Tensor, + bonus_token_ids: torch.Tensor, + draft_probs: torch.Tensor, + draft_token_ids: torch.Tensor, + ) -> None: + devices = [ + t.device for t in + [target_probs, bonus_token_ids, draft_probs, draft_token_ids] + ] + assert all([devices[0] == device for device in devices]) + + def _raise_if_out_of_bounds_vocab( + self, + vocab_size: int, + bonus_token_ids: torch.Tensor, + draft_token_ids: torch.Tensor, + ) -> None: + assert torch.all(bonus_token_ids < vocab_size) + assert torch.all(bonus_token_ids >= 0) + assert torch.all(draft_token_ids < vocab_size) + assert torch.all(draft_token_ids >= 0) + + +# torch.multinomial forces a GPU<->CPU sync. +# Therefore, we use an optimized implementation instead that skips the sync. +# Note that we always sample with replacement. +# probs will be modified in place, but this is fine, as we pass +# in a copy already. +@torch.jit.script +def _multinomial( + probs: torch.Tensor, + num_samples: int, +) -> torch.Tensor: + if num_samples > 1: + # This is equivalent to torch.repeat_interleaved (which also + # forces a GPU<->CPU sync). 
+ probs = probs[:, None, :].expand(probs.shape[0], num_samples, + probs.shape[1]).contiguous().view( + -1, probs.shape[1]) + q = torch.empty_like(probs).exponential_(1.0) + return probs.div_(q).argmax(dim=1).view(-1, num_samples) diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index ebc9afc1be67..e8b1d3e570ff 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -76,7 +76,7 @@ def forward( logits.div_(sampling_tensors.temperatures.unsqueeze_(dim=1)) if do_top_p_top_k: - logits = _apply_top_p_top_k(logits, sampling_tensors.top_ps, + logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps, sampling_tensors.top_ks) if do_min_p: @@ -185,27 +185,27 @@ def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor, return logits -def _apply_top_p_top_k( +def _apply_top_k_top_p( logits: torch.Tensor, p: torch.Tensor, k: torch.Tensor, ) -> torch.Tensor: - logits_sort, logits_idx = logits.sort(dim=-1, descending=True) + logits_sort, logits_idx = logits.sort(dim=-1, descending=False) + + # Apply top-k. + top_k_mask = logits_sort.size(1) - k.to(torch.long) + # Get all the top_k values. + top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1)) + top_k_mask = logits_sort < top_k_mask + logits_sort.masked_fill_(top_k_mask, -float("inf")) # Apply top-p. probs_sort = logits_sort.softmax(dim=-1) - probs_sum = probs_sort.cumsum(dim=-1).sub_(probs_sort) - top_p_mask = probs_sum > p.unsqueeze_(dim=1) - - # Apply top-k. - # Create a mask for the top-k elements. - top_k_mask = torch.arange(logits_idx.shape[-1], device=logits_idx.device) - top_k_mask = top_k_mask.expand(logits_idx.shape[0], -1) - top_k_mask = top_k_mask >= k.unsqueeze_(dim=1) - - # Final mask. 
- mask = (top_p_mask | top_k_mask) - logits_sort.masked_fill_(mask, -float("inf")) + probs_sum = probs_sort.cumsum(dim=-1) + top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1) + # at least one + top_p_mask[:, -1] = False + logits_sort.masked_fill_(top_p_mask, -float("inf")) # Re-sort the probabilities. src = torch.arange(logits_idx.shape[-1], diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index f60ea640359b..b4568ae89bba 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -33,10 +33,11 @@ "MptForCausalLM": ("mpt", "MPTForCausalLM"), "MPTForCausalLM": ("mpt", "MPTForCausalLM"), "OPTForCausalLM": ("opt", "OPTForCausalLM"), - "PhiForCausalLM": ("phi_1_5", "PhiForCausalLM"), + "PhiForCausalLM": ("phi", "PhiForCausalLM"), "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"), "RWForCausalLM": ("falcon", "FalconForCausalLM"), - "YiForCausalLM": ("yi", "YiForCausalLM"), + "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"), + "YiForCausalLM": ("yi", "YiForCausalLM") } # Models not supported by ROCm. 
diff --git a/vllm/model_executor/models/phi_1_5.py b/vllm/model_executor/models/phi.py similarity index 76% rename from vllm/model_executor/models/phi_1_5.py rename to vllm/model_executor/models/phi.py index 9d4424dd0890..d14326196828 100644 --- a/vllm/model_executor/models/phi_1_5.py +++ b/vllm/model_executor/models/phi.py @@ -62,20 +62,6 @@ KVCache = Tuple[torch.Tensor, torch.Tensor] -class PhiEmbedding(nn.Module): - - def __init__(self, config: PretrainedConfig): - super().__init__() - - self.wte = VocabParallelEmbedding( - config.vocab_size, - config.hidden_size, - ) - - def forward(self, input_ids: torch.LongTensor): - return self.wte(input_ids) - - class PhiAttention(nn.Module): def __init__(self, @@ -93,27 +79,22 @@ def __init__(self, tensor_model_parallel_world_size) # pylint: disable=C0103 - self.Wqkv = QKVParallelLinear( - self.hidden_size, - self.head_size, - self.total_num_heads, - linear_method=linear_method, - ) self.qkv_proj = QKVParallelLinear( - config.hidden_size, + self.hidden_size, self.head_size, self.total_num_heads, - bias=False, + bias=True, linear_method=linear_method, ) - self.out_proj = RowParallelLinear( + self.dense = RowParallelLinear( self.hidden_size, self.hidden_size, linear_method=linear_method, ) scaling = self.head_size**-0.5 - rotary_dim = config.rotary_dim + rotary_dim = int(config.partial_rotary_factor * + (config.hidden_size // config.num_attention_heads)) assert rotary_dim % 2 == 0 # pylint: disable=C0301 @@ -136,12 +117,12 @@ def forward( kv_cache: KVCache, input_metadata: InputMetadata, ) -> torch.Tensor: - qkv, _ = self.Wqkv(hidden_states) + qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.chunk(chunks=3, dim=-1) q, k = self.rotary_emb(position_ids, q, k) k_cache, v_cache = kv_cache attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) - output, _ = self.out_proj(attn_output) + output, _ = self.dense(attn_output) return output @@ -166,8 +147,7 @@ def __init__(self, linear_method=linear_method, ) 
quant_config = getattr(linear_method, "quant_config", None) - self.act = get_act_fn(config.activation_function, quant_config, - n_inner) + self.act = get_act_fn(config.hidden_act, quant_config, n_inner) def forward(self, hidden_states): hidden_states, _ = self.fc1(hidden_states) @@ -182,9 +162,9 @@ def __init__(self, config: PretrainedConfig, linear_method: Optional[LinearMethodBase] = None): super().__init__() - self.ln = nn.LayerNorm(config.hidden_size, - eps=config.layer_norm_epsilon) - self.mixer = PhiAttention(config, linear_method) + self.input_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + self.self_attn = PhiAttention(config, linear_method) self.mlp = PhiMLP(config, linear_method) def forward( @@ -195,8 +175,8 @@ def forward( input_metadata: InputMetadata, ) -> torch.Tensor: residual = hidden_states - hidden_states = self.ln(hidden_states) - attn_outputs = self.mixer( + hidden_states = self.input_layernorm(hidden_states) + attn_outputs = self.self_attn( position_ids=position_ids, hidden_states=hidden_states, kv_cache=kv_cache, @@ -215,11 +195,14 @@ def __init__(self, super().__init__() self.config = config self.linear_method = linear_method - self.embd = PhiEmbedding(config) - self.h = nn.ModuleList([ + self.embed_tokens = VocabParallelEmbedding(config.vocab_size, + config.hidden_size) + self.layers = nn.ModuleList([ PhiLayer(config, linear_method) for _ in range(config.num_hidden_layers) ]) + self.final_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) def forward( self, @@ -228,27 +211,19 @@ def forward( kv_caches: List[KVCache], input_metadata: InputMetadata, ) -> torch.Tensor: - hidden_states = self.embd(input_ids) + hidden_states = self.embed_tokens(input_ids) for i in range(self.config.num_hidden_layers): - layer = self.h[i] + layer = self.layers[i] hidden_states = layer( positions, hidden_states, kv_caches[i], input_metadata, ) - return hidden_states - -class PhiCausalLMHead(nn.Module): + 
hidden_states = self.final_layernorm(hidden_states) - def __init__(self, config: PretrainedConfig): - super().__init__() - self.ln = nn.LayerNorm(config.hidden_size, - eps=config.layer_norm_epsilon) - self.linear = ParallelLMHead(config.vocab_size, - config.hidden_size, - bias=True) + return hidden_states class PhiForCausalLM(nn.Module): @@ -260,8 +235,11 @@ def __init__(self, self.config = config self.linear_method = linear_method - self.transformer = PhiModel(config, linear_method) - self.lm_head = PhiCausalLMHead(config) + self.model = PhiModel(config, linear_method) + + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + bias=True) self.sampler = Sampler(config.vocab_size) def forward( @@ -271,9 +249,9 @@ def forward( kv_caches: List[KVCache], input_metadata: InputMetadata, ) -> torch.Tensor: - hidden_states = self.transformer(input_ids, positions, kv_caches, - input_metadata) - hidden_states = self.lm_head.ln(hidden_states) + hidden_states = self.model(input_ids, positions, kv_caches, + input_metadata) + return hidden_states def sample( @@ -281,7 +259,7 @@ def sample( hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - head = self.lm_head.linear + head = self.lm_head next_tokens = self.sampler(head.weight, hidden_states, sampling_metadata, head.bias) return next_tokens @@ -291,17 +269,37 @@ def load_weights(self, cache_dir: Optional[str] = None, load_format: str = "auto", revision: Optional[str] = None): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v") + ] params_dict = dict(self.named_parameters()) + for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): if "rotary_emb.inv_freq" in name: continue - # Skip loading extra bias for GPTQ models. 
- if name.endswith(".bias") and name not in params_dict: - continue - # pylint: disable=E1136 - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # pylint: disable=E1136 + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py new file mode 100644 index 000000000000..de5e246021b3 --- /dev/null +++ b/vllm/model_executor/models/stablelm.py @@ -0,0 +1,299 @@ +# coding=utf-8 +# Copyright 2023 Stability AI, EleutherAI, and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# This code is based off the following work: +# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/modeling_stablelm_epoch.py +# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json +"""Inference-only StabeLM (https://github.com/Stability-AI/StableLM) model compatible with HuggingFace weights.""" +from typing import List, Optional, Tuple + +import torch +from torch import nn +from transformers import PretrainedConfig + +from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.linear import (LinearMethodBase, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding, ParallelLMHead) +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_world_size) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) +from vllm.sequence import SamplerOutput + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +class StablelmMLP(nn.Module): + + def __init__(self, + config: PretrainedConfig, + linear_method: Optional[LinearMethodBase] = None) -> None: + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_up_proj = MergedColumnParallelLinear( + config.hidden_size, [config.intermediate_size] * 2, + bias=False, + linear_method=linear_method) + self.down_proj = RowParallelLinear(config.intermediate_size, + config.hidden_size, + bias=False) + self.act_fn = SiluAndMul() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + 
gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class StablelmAttention(nn.Module): + + def __init__(self, + config: PretrainedConfig, + linear_method: Optional[LinearMethodBase] = None) -> None: + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = config.num_attention_heads + self.num_heads = self.total_num_heads // tp_size + + self.total_num_key_value_heads = config.num_key_value_heads + if self.total_num_key_value_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_key_value_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_key_value_heads == 0 + self.num_key_value_heads = max( + 1, self.total_num_key_value_heads // tp_size) + self.head_dim = self.hidden_size // self.total_num_heads + self.max_position_embeddings = config.max_position_embeddings + self.rotary_ndims = int(self.head_dim * self.config.rope_pct) + self.scaling = self.head_dim**-0.5 + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_key_value_heads * self.head_dim + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads}).") + + self.qkv_proj = QKVParallelLinear(self.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_key_value_heads, + bias=False, + linear_method=linear_method) + self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim, + self.hidden_size, + bias=False, + linear_method=linear_method) + self.rotary_ndims = int(self.head_dim * self.config.rope_pct) + self.rotary_emb = get_rope( + 
self.head_dim, + rotary_dim=self.rotary_ndims, + max_position=self.config.max_position_embeddings, + base=self.config.rope_theta, + ) + self.attn = PagedAttention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_key_value_heads) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + k_cache, v_cache = kv_cache + attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class StablelmDecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.self_attn = StablelmAttention(config) + self.mlp = StablelmMLP(config, linear_method) + self.input_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.norm_eps) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + input_metadata=input_metadata, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states, residual + + +class StableLMEpochModel(nn.Module): + + def __init__(self, + config: PretrainedConfig, + linear_method: 
Optional[LinearMethodBase] = None) -> None: + super().__init__() + # self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id) + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.layers = nn.ModuleList([ + StablelmDecoderLayer(config, linear_method) + for _ in range(config.num_hidden_layers) + ]) + self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states, residual = layer( + positions, + hidden_states, + kv_caches[i], + input_metadata, + ) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class StablelmForCausalLM(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.config = config + self.linear_method = linear_method + self.model = StableLMEpochModel(config, linear_method) + self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) + self.sampler = Sampler(config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, kv_caches, + input_metadata) + return hidden_states + + def sample( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(self.lm_head.weight, hidden_states, + sampling_metadata) + return next_tokens + + def load_weights(self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None): + stacked_params_mapping = [ + # (param_name, 
shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, load_format, revision): + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/parallel_utils/parallel_state.py b/vllm/model_executor/parallel_utils/parallel_state.py index 9a5e2889381d..ecc94f025234 100644 --- a/vllm/model_executor/parallel_utils/parallel_state.py +++ b/vllm/model_executor/parallel_utils/parallel_state.py @@ -83,6 +83,31 @@ def initialize_model_parallel( _PIPELINE_GLOBAL_RANKS = ranks +def ensure_model_parallel_initialized( + tensor_model_parallel_size: int, + pipeline_model_parallel_size: int, +) -> None: + """Helper to initialize model parallel groups if they are not initialized, + or ensure tensor-parallel and pipeline-parallel sizes are equal to expected + values if the model parallel groups are initialized. 
+ """ + if not model_parallel_is_initialized(): + initialize_model_parallel(tensor_model_parallel_size, + pipeline_model_parallel_size) + return + + assert ( + get_tensor_model_parallel_world_size() == tensor_model_parallel_size + ), ("tensor parallel group already initialized, but of unexpected size: " + f"{get_tensor_model_parallel_world_size()=} vs. " + f"{tensor_model_parallel_size=}") + assert (get_pipeline_model_parallel_world_size( + ) == pipeline_model_parallel_size), ( + "pipeline parallel group already initialized, but of unexpected size: " + f"{get_pipeline_model_parallel_world_size()=} vs. " + f"{pipeline_model_parallel_size=}") + + def model_parallel_is_initialized(): """Check if tensor and pipeline parallel groups are initialized.""" return (_TENSOR_MODEL_PARALLEL_GROUP is not None diff --git a/vllm/utils.py b/vllm/utils.py index c32047ac27dc..874b4966b7cc 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -58,7 +58,13 @@ def in_wsl() -> bool: def get_ip() -> str: - return socket.gethostbyname(socket.gethostname()) + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + s.connect(("8.8.8.8", 80)) # Doesn't need to be reachable + return s.getsockname()[0] + + +def get_distributed_init_method(ip: str, port: int) -> str: + return f"tcp://{ip}:{port}" def get_open_port() -> int: diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index be2803089f51..03d729afaa8a 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -235,9 +235,11 @@ def _prepare_decode( input_block_tables[i, :len(block_table)] = block_table block_tables = torch.tensor(input_block_tables, device="cuda") else: + max_block_table_len = max( + len(block_table) for block_table in block_tables) block_tables = _make_tensor_with_pad( block_tables, - max_len=max_context_len, + max_len=max_block_table_len, pad=0, dtype=torch.int, device="cuda", @@ -504,7 +506,9 @@ def capture_model(self, kv_caches: List[KVCache]) -> None: "use '--enforce-eager' in the 
CLI.") logger.info("CUDA graphs can take additional 1~3 GiB memory per GPU. " "If you are running out of memory, consider decreasing " - "`gpu_memory_utilization` or enforcing eager mode.") + "`gpu_memory_utilization` or enforcing eager mode. " + "You can also reduce the `max_num_seqs` as needed " + "to decrease memory usage.") start_time = time.perf_counter() # Prepare dummy inputs. These will be reused for all batch sizes. @@ -517,9 +521,15 @@ def capture_model(self, kv_caches: List[KVCache]) -> None: context_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda() block_tables = torch.from_numpy(self.graph_block_tables).cuda() + graph_batch_size = _get_graph_batch_size( + self.scheduler_config.max_num_seqs) + batch_size_capture_list = [ + bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size + ] + # NOTE: Capturing the largest batch size first may help reduce the # memory usage of CUDA graph. - for batch_size in reversed(_BATCH_SIZES_TO_CAPTURE): + for batch_size in reversed(batch_size_capture_list): # Create dummy input_metadata. input_metadata = InputMetadata( is_prompt=False, diff --git a/vllm/worker/spec_decode/multi_step_worker.py b/vllm/worker/spec_decode/multi_step_worker.py new file mode 100644 index 000000000000..591d1b1300c8 --- /dev/null +++ b/vllm/worker/spec_decode/multi_step_worker.py @@ -0,0 +1,178 @@ +from typing import List, Dict +import copy + +import torch + +from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.worker.worker import Worker + + +class MultiStepWorker(Worker): + """The MultiStepWorker is equivalent to a Worker except that it allows + multiple forward passes in a single call, assuming the scheduler has + allocated enough space to store the additional KV. This reduces overhead + by invoking the scheduler less. + + The MultiStepWorker does not support cache swap operations, or beam search. + Cache swap operations do not require large modifications. 
On the other hand, + beam search requires memory allocations during sequence forks and thus + requires more thought for MultiStepWorker support. + """ + + @torch.inference_mode() + def execute_model_multi_step( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + blocks_to_swap_in: Dict[int, int], + blocks_to_swap_out: Dict[int, int], + blocks_to_copy: Dict[int, List[int]], + num_steps: int, + ) -> List[SamplerOutput]: + """Run the model forward pass num_steps times. Returns the list of + sampler output, one per model forward pass. + """ + self._raise_if_unsupported(seq_group_metadata_list, blocks_to_swap_in, + blocks_to_swap_out, blocks_to_copy) + + # Shallow copy input data so modifications (such as appending tokens) + # do not cause side-effects. + copied_seq_group_metadata_list = self._shallow_copy_inputs( + seq_group_metadata_list) + + # Assert enough KV space for num_steps tokens per sequence. + self._assert_enough_kv_space(seq_group_metadata_list, num_steps) + + # Run model num_steps times. + model_outputs = [] + for _ in range(num_steps): + model_output = super().execute_model( + seq_group_metadata_list=copied_seq_group_metadata_list, + blocks_to_swap_in=blocks_to_swap_in, + blocks_to_swap_out=blocks_to_swap_out, + blocks_to_copy=blocks_to_copy, + ) + + self._append_new_tokens(model_output, + copied_seq_group_metadata_list) + model_outputs.append(model_output) + + return model_outputs + + def _append_new_tokens( + self, model_output: SamplerOutput, + seq_group_metadata_list: SequenceGroupMetadata) -> None: + """Given model output from a single run, append the tokens to the + sequences. This is normally done outside of the worker, but it is + required if the worker is to perform multiple forward passes. 
+ """ + for seq_group_metadata, sequence_group_outputs in zip( + seq_group_metadata_list, model_output): + seq_group_metadata.is_prompt = False + + for seq_output in sequence_group_outputs.samples: + # NOTE: Beam search is not supported, so we can assume that + # parent_seq_id == seq_id. + seq = seq_group_metadata.seq_data[seq_output.parent_seq_id] + + token_id = seq_output.output_token + token_logprob = seq_output.logprobs[token_id] + + seq.append_token_id(token_id, token_logprob) + + def _shallow_copy_inputs( + self, seq_group_metadata_list: List[SequenceGroupMetadata] + ) -> List[SequenceGroupMetadata]: + """Copy input data structures to remove side-effects when input data + structures are shared with other modules. + + The multi-step worker must be able to append tokens to sequences after + a forward pass. This necessitates modification of the data structures + used by the worker. Since these data structures are shared with other + parts of vLLM, like the scheduler, we must take care not to introduce + unexpected side-effects. + + When Ray is used to orchestrate worker processes (such as when the + tensor-parallel degree is >1), this is not a problem because the input + datastructures will be serialized and created anew in the worker + process. + + However, when Ray is not used to orchestrate the worker processes (such + as when the tensor-parallel degree is 1), this is a problem. We avoid + the problem by shallow-copying the input datastructures (specifically, + the parts that will change in multiple steps). + """ + + # Shallow-copy the list of SequenceGroupMetadata. This allows us to + # append tokens and change is_prompt without external side-effects. + new_seq_group_metadata_list = [] + + for old_seq_group_metadata in seq_group_metadata_list: + # We must shallow-copy seq_group_metadata as is_prompt could change. 
+ seq_group_metadata = copy.copy(old_seq_group_metadata) + new_seq_group_metadata_list.append(seq_group_metadata) + + # We must shallow-copy seq_data as we will append token ids + new_seq_data = {} + for seq_id, old_seq_data in seq_group_metadata.seq_data.items(): + new_seq_data[seq_id] = copy.copy(old_seq_data) + new_seq_data[ + seq_id].output_token_ids = old_seq_data.output_token_ids[:] + + seq_group_metadata.seq_data = new_seq_data + + return new_seq_group_metadata_list + + def _assert_enough_kv_space( + self, seq_group_metadata_list: List[SequenceGroupMetadata], + num_steps: int) -> None: + """Assert there are enough physical blocks per sequence to store the + current KV plus additional KV from num_steps tokens. + """ + assert self.model_runner.block_size is not None + for seq_group_metadata in seq_group_metadata_list: + # Only one seq_id is guaranteed because there is no beam search. + seq_id = list(seq_group_metadata.seq_data.keys())[0] + seq = seq_group_metadata.seq_data[seq_id] + + # After num_steps, the seq len will be the current seq len + # plus one token per step. + final_seq_len = seq.get_len() + num_steps + + # We will have final_seq_len - 1 KV because vLLM saves KV for a + # token in the iteration after the token was generated. + required_num_kv_slots = final_seq_len - 1 + + # The allocated number of kv slots is the number of allocated blocks + # times the number of slots of block. + number_physical_blocks = len( + seq_group_metadata.block_tables[seq_id]) + allocated_kv_slots = (number_physical_blocks * + self.model_runner.block_size) + + if required_num_kv_slots > allocated_kv_slots: + request_id = seq_group_metadata.request_id + raise ValueError( + "The worker attempted to run " + f"{num_steps} times but found insufficient KV space for " + f"{request_id=} {seq_id=}. 
({allocated_kv_slots=} " + f"{required_num_kv_slots=}).") + + def _raise_if_unsupported( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + blocks_to_swap_in: Dict[int, int], + blocks_to_swap_out: Dict[int, int], + blocks_to_copy: Dict[int, List[int]], + ) -> None: + """MultiStepWorker does not yet implement support for cache swap + operations or beam search. + """ + if any([blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy]): + raise NotImplementedError( + "MultiStepWorker does not support cache operations") + + if any( + len(seq_group_metadata.seq_data.keys()) != 1 + for seq_group_metadata in seq_group_metadata_list): + raise NotImplementedError( + "MultiStepWorker does not support beam search.") diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 6c83f708bd9c..d1233b0a82b6 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -11,7 +11,7 @@ from vllm.model_executor.parallel_utils.communication_op import ( broadcast_object_list) from vllm.model_executor.parallel_utils.parallel_state import ( - initialize_model_parallel) + ensure_model_parallel_initialized) from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.worker.cache_engine import CacheEngine from vllm.worker.model_runner import ModelRunner @@ -87,6 +87,14 @@ def profile_num_available_blocks( gpu_memory_utilization: float, cpu_swap_space: int, ) -> Tuple[int, int]: + """Profiles the peak memory usage of the model and returns the maximum + number of GPU and CPU cache blocks that can be allocated. + + Args: + block_size: The size of the cache block. + gpu_memory_utilization: The fraction of the total GPU memory to use. + cpu_swap_space: The size of the CPU swap space in bytes. + """ # Profile the memory usage of the model and get the maximum number of # cache blocks that can be allocated with the remaining free memory. torch.cuda.empty_cache() @@ -218,8 +226,8 @@ def _init_distributed_environment( # A small all_reduce for warmup. 
torch.distributed.all_reduce(torch.zeros(1).cuda()) - initialize_model_parallel(parallel_config.tensor_parallel_size, - parallel_config.pipeline_parallel_size) + ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, + parallel_config.pipeline_parallel_size) def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): @@ -231,4 +239,6 @@ def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): raise ValueError( "Bfloat16 is only supported on GPUs with compute capability " f"of at least 8.0. Your {gpu_name} GPU has compute capability " - f"{compute_capability[0]}.{compute_capability[1]}.") + f"{compute_capability[0]}.{compute_capability[1]}. " + "You can use float16 instead by explicitly setting the "
+ "`dtype` flag in CLI, for example: --dtype=half.")