diff --git a/ci/L0_backend_vllm/multi_lora/download.py b/ci/L0_backend_vllm/multi_lora/download.py new file mode 100644 index 00000000..3b0671c7 --- /dev/null +++ b/ci/L0_backend_vllm/multi_lora/download.py @@ -0,0 +1,47 @@ +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +from huggingface_hub import snapshot_download + +if __name__ == "__main__": + # download lora weight alpaca + snapshot_download( + repo_id="swathijn/GemmaDoll-2b-dolly-LORA-Tune", + local_dir="./weights/loras/GemmaDoll", + max_workers=8, + ) + # download lora weight GemmaSheep + snapshot_download( + repo_id="eduardo-alvarez/GemmaSheep-2B-LORA-TUNED", + local_dir="./weights/loras/GemmaSheep", + max_workers=8, + ) + # download backbone weight google/gemma-2b + snapshot_download( + repo_id="unsloth/gemma-2b", + local_dir="./weights/backbone/gemma-2b", + max_workers=8, + ) diff --git a/ci/L0_backend_vllm/multi_lora/multi_lora_test.py b/ci/L0_backend_vllm/multi_lora/multi_lora_test.py new file mode 100644 index 00000000..09a163f6 --- /dev/null +++ b/ci/L0_backend_vllm/multi_lora/multi_lora_test.py @@ -0,0 +1,181 @@ +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import os +import sys +import unittest +from functools import partial +from typing import List + +import tritonclient.grpc as grpcclient +from tritonclient.utils import * + +sys.path.append("../../common") +from test_util import AsyncTestResultCollector, UserData, callback, create_vllm_request + +PROMPTS = ["Instruct: What do you think of Computer Science?\nOutput:"] +SAMPLING_PARAMETERS = {"temperature": "0", "top_p": "1"} + +server_enable_lora = True + + +class VLLMTritonLoraTest(AsyncTestResultCollector): + def setUp(self): + self.triton_client = grpcclient.InferenceServerClient(url="localhost:8001") + self.vllm_model_name = "vllm_llama_multi_lora" + + def _test_vllm_model( + self, + prompts: List[str], + sampling_parameters, + lora_name: List[str], + server_enable_lora=True, + stream=False, + exclude_input_in_output=None, + expected_output=None, + ): + assert len(prompts) == len( + lora_name + ), "The number of prompts and lora names should be the same" + user_data = UserData() + number_of_vllm_reqs = len(prompts) + + self.triton_client.start_stream(callback=partial(callback, user_data)) + for i in range(number_of_vllm_reqs): + lora = lora_name[i] if lora_name else None + sam_para_copy = sampling_parameters.copy() + if lora is not None: + sam_para_copy["lora_name"] = lora + request_data = create_vllm_request( + prompts[i], + i, + stream, + sam_para_copy, + self.vllm_model_name, + exclude_input_in_output=exclude_input_in_output, + ) + 
self.triton_client.async_stream_infer( + model_name=self.vllm_model_name, + request_id=request_data["request_id"], + inputs=request_data["inputs"], + outputs=request_data["outputs"], + parameters=sampling_parameters, + ) + + for i in range(number_of_vllm_reqs): + result = user_data._completed_requests.get() + if type(result) is InferenceServerException: + print(result.message()) + if server_enable_lora: + self.assertEqual( + str(result.message()), + f"LoRA {lora_name[i]} is not supported, we currently support ['doll', 'sheep']", + "InferenceServerException", + ) + else: + self.assertEqual( + str(result.message()), + "LoRA feature is not enabled.", + "InferenceServerException", + ) + self.triton_client.stop_stream() + return + + output = result.as_numpy("text_output") + self.assertIsNotNone(output, "`text_output` should not be None") + if expected_output is not None: + self.assertEqual( + output, + expected_output[i], + 'Actual and expected outputs do not match.\n \ + Expected "{}" \n Actual:"{}"'.format( + output, expected_output[i] + ), + ) + + self.triton_client.stop_stream() + + def test_multi_lora_requests(self): + self.triton_client.load_model(self.vllm_model_name) + sampling_parameters = {"temperature": "0", "top_p": "1"} + # make two requests separately to avoid the different arrival of response answers + prompt_1 = ["Instruct: What do you think of Computer Science?\nOutput:"] + lora_1 = ["doll"] + expected_output = [ + b" I think it is a very interesting subject.\n\nInstruct: What do you" + ] + self._test_vllm_model( + prompt_1, + sampling_parameters, + lora_name=lora_1, + server_enable_lora=server_enable_lora, + stream=False, + exclude_input_in_output=True, + expected_output=expected_output, + ) + + prompt_2 = ["Instruct: Tell me more about soccer\nOutput:"] + lora_2 = ["sheep"] + expected_output = [ + b" I love soccer. 
I play soccer every day.\nInstruct: Tell me" + ] + self._test_vllm_model( + prompt_2, + sampling_parameters, + lora_name=lora_2, + server_enable_lora=server_enable_lora, + stream=False, + exclude_input_in_output=True, + expected_output=expected_output, + ) + self.triton_client.unload_model(self.vllm_model_name) + + def test_none_exist_lora(self): + self.triton_client.load_model(self.vllm_model_name) + prompts = [ + "Instruct: What is the capital city of France?\nOutput:", + ] + loras = ["bactrian"] + sampling_parameters = {"temperature": "0", "top_p": "1"} + self._test_vllm_model( + prompts, + sampling_parameters, + lora_name=loras, + server_enable_lora=server_enable_lora, + stream=False, + exclude_input_in_output=True, + expected_output=None, # this request will lead to lora not supported error, so there is no expected output + ) + self.triton_client.unload_model(self.vllm_model_name) + + def tearDown(self): + self.triton_client.close() + + +if __name__ == "__main__": + server_enable_lora = os.environ.get("SERVER_ENABLE_LORA", "false").lower() == "true" + + unittest.main() diff --git a/ci/L0_backend_vllm/multi_lora/test.sh b/ci/L0_backend_vllm/multi_lora/test.sh new file mode 100755 index 00000000..57544608 --- /dev/null +++ b/ci/L0_backend_vllm/multi_lora/test.sh @@ -0,0 +1,169 @@ +#!/bin/bash +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. 
+# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +source ../../common/util.sh + +TRITON_DIR=${TRITON_DIR:="/opt/tritonserver"} +SERVER=${TRITON_DIR}/bin/tritonserver +BACKEND_DIR=${TRITON_DIR}/backends +SERVER_ARGS="--model-repository=`pwd`/models --backend-directory=${BACKEND_DIR} --model-control-mode=explicit --log-verbose=1" +SERVER_LOG="./multi_lora_server.log" +CLIENT_LOG="./multi_lora_client.log" +TEST_RESULT_FILE='test_results.txt' +CLIENT_PY="./multi_lora_test.py" +DOWNLOAD_PY="./download.py" +SAMPLE_MODELS_REPO="../../../samples/model_repository" +EXPECTED_NUM_TESTS=2 + +# first we download weights +pip install -U huggingface_hub + +rm -rf weights && mkdir -p weights/loras/GemmaDoll && mkdir -p weights/loras/GemmaSheep +mkdir -p weights/backbone/gemma-2b + +python3 $DOWNLOAD_PY -v > $CLIENT_LOG 2>&1 + +rm -rf models && mkdir -p models +cp -r ${SAMPLE_MODELS_REPO}/vllm_model models/vllm_llama_multi_lora + +export SERVER_ENABLE_LORA=true + +model_json=$(cat < models/vllm_llama_multi_lora/1/model.json + +multi_lora_json=$(cat < 
models/vllm_llama_multi_lora/1/multi_lora.json + +RET=0 +# If it is the first time launching triton server with gemma-2b and multi-lora feature, +# it may take more than 1 minutes. Please wait. +SERVER_TIMEOUT=60000 + +run_server +if [ "$SERVER_PID" == "0" ]; then + cat $SERVER_LOG + echo -e "\n***\n*** Failed to start $SERVER\n***" + exit 1 +fi + +set +e +python3 $CLIENT_PY -v > $CLIENT_LOG 2>&1 + +if [ $? -ne 0 ]; then + cat $CLIENT_LOG + echo -e "\n***\n*** Running $CLIENT_PY FAILED. \n***" + RET=1 +else + check_test_results $TEST_RESULT_FILE $EXPECTED_NUM_TESTS + if [ $? -ne 0 ]; then + cat $CLIENT_LOG + echo -e "\n***\n*** Test Result Verification FAILED.\n***" + RET=1 + fi +fi +set -e + +kill $SERVER_PID +wait $SERVER_PID + +# disable lora +export SERVER_ENABLE_LORA=false +model_json=$(cat < models/vllm_llama_multi_lora/1/model.json + +run_server +if [ "$SERVER_PID" == "0" ]; then + cat $SERVER_LOG + echo -e "\n***\n*** Failed to start $SERVER\n***" + exit 1 +fi + +set +e +python3 $CLIENT_PY -v >> $CLIENT_LOG 2>&1 + +if [ $? -ne 0 ]; then + cat $CLIENT_LOG + echo -e "\n***\n*** Running $CLIENT_PY FAILED. \n***" + RET=1 +else + check_test_results $TEST_RESULT_FILE $EXPECTED_NUM_TESTS + if [ $? -ne 0 ]; then + cat $CLIENT_LOG + echo -e "\n***\n*** Test Result Verification FAILED.\n***" + RET=1 + fi +fi +set -e + +kill $SERVER_PID +wait $SERVER_PID + +rm -rf models/ +rm -rf weights/ + +if [ $RET -eq 1 ]; then + cat $CLIENT_LOG + cat $SERVER_LOG + echo -e "\n***\n*** Multi LoRA test FAILED. \n***" +else + echo -e "\n***\n*** Multi LoRA test PASSED. \n***" +fi + +collect_artifacts_from_subdir + +exit $RET \ No newline at end of file diff --git a/ci/L0_backend_vllm/test.sh b/ci/L0_backend_vllm/test.sh index 2b68616d..5d3d4d5c 100755 --- a/ci/L0_backend_vllm/test.sh +++ b/ci/L0_backend_vllm/test.sh @@ -1,5 +1,5 @@ #!/bin/bash -# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions @@ -26,7 +26,7 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. RET=0 -SUBTESTS="accuracy_test request_cancellation enabled_stream vllm_backend" +SUBTESTS="accuracy_test request_cancellation enabled_stream vllm_backend multi_lora" python3 -m pip install --upgrade pip && pip3 install tritonclient[grpc] diff --git a/docs/llama_multi_lora_tutorial.md b/docs/llama_multi_lora_tutorial.md new file mode 100644 index 00000000..087e1c1c --- /dev/null +++ b/docs/llama_multi_lora_tutorial.md @@ -0,0 +1,271 @@ + + +# Tutorial on deploying multi-lora vLLM backend in Triton The idea of multi-lora was proposed recently, for more please refer to: + ++ [S-LoRA: Serving Thousands of Concurrent LoRA Adapters](https://arxiv.org/abs/2311.03285) ++ [Punica: Multi-Tenant LoRA Serving](https://arxiv.org/abs/2310.18547) + +vLLM now supports multi-lora, which integrated the `Punica` feature and related cuda kernels. See this [PR](https://github.com/vllm-project/vllm/pull/1804) for more. (2024-01-24 this PR has been merged into the main branch of vLLM) + +The following tutorial demonstrates how to deploy **a LLaMa model** with **multiple loras** on Triton Inference Server using Triton's [Python-based](https://github.com/triton-inference-server/backend/blob/main/docs/python_based_backends.md#python-based-backends) [vLLM](https://github.com/triton-inference-server/vllm_backend/tree/main) backend. + +> Before you continue reading, it's important to note that all command-line instructions containing `` in the document cannot be used directly by copying and pasting. > > `` represents the Triton version, and you must specify the Triton version you want to use for the bash command to work.
+ +## Step 1: Start a docker container for triton-vllm serving + +**A docker container is strongly recommended for serving**, and this tutorial will only demonstrate how to launch triton in docker env. + +First, start a docker container using the tritonserver image with vLLM backend from [NGC](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/tritonserver/tags): + +```bash +# NOTICE: you must first cd to your vllm_workspace path outside the container. +mkdir vllm_workspace && cd vllm_workspace + +sudo docker run --gpus all -it --net=host -p 8001:8001 --shm-size=12G \ +--ulimit memlock=-1 --ulimit stack=67108864 -v ${PWD}:/vllm_workspace \ +-w /vllm_workspace nvcr.io/nvidia/tritonserver:-vllm-python-py3 \ +/bin/bash +``` + +**NOTICE:** the version of triton docker image should be configured, here we use `` to symbolize. + +Triton's vLLM container has been introduced starting from 23.10 release, and `multi-lora` experimental support was added in vLLM v0.3.0 release. + +> Docker image version `nvcr.io/nvidia/tritonserver:24.05-vllm-python-py3` or higher version is strongly recommended. + +--- + + + +For **pre-24.05 containers**, the docker images didn't support the multi-lora feature, so you need to replace the `model.py` provided in the container at `/opt/tritonserver/backends/vllm/model.py` with the most up to date version. Just follow this command: + +Download the `model.py` script from github: + +```bash +wget -P /opt/tritonserver/backends/vllm/ https://raw.githubusercontent.com/triton-inference-server/vllm_backend/r/src/model.py +``` + +**Notice:** `r` is the triton version you need to configure to r24.04 or later release. + +This command will download the `model.py` script to the Triton vllm backend directory which will enable the multi-lora feature.
+ +## Step 2: Install vLLM with multi-lora feature + +We are now in the docker container, and **the following operations will be done in container environment.** + +```bash +cd /vllm_workspace +``` + +**NOTICE**: To enable multi-lora feature and speed up the inference, vLLM has integrated punica kernels. To compile the punica kernels, you need to turn the `VLLM_INSTALL_PUNICA_KERNELS` env variable on to allow punica kernels compilation. + +By default, the punica kernels will **NOT** be compiled when installing the vLLM. + +__2.1 install with pip__ + +For Triton version before 24.05, you need the following command: + +```bash +VLLM_INSTALL_PUNICA_KERNELS=1 pip install vllm==0.4.0.post1 +``` + +__2.2 build from source__ + +As alternative, you can build vLLM from source code: + +git clone vllm repository: + +```bash +git clone https://github.com/vllm-project/vllm.git +``` + +All you need to do is to follow the simple step: + +```bash +cd vllm +VLLM_INSTALL_PUNICA_KERNELS=1 pip install . +``` + +This may take you 5-10 mins. + +## Step 3: Prepare your weights + +To support multi-lora on Triton, you need to manage your file path for **model backbone** and **lora weights** separately. + +A typical weights repository can be as follows: + +``` +weights +├── backbone +│ └── llama-7b-hf +└── loras + ├── alpaca-lora-7b + └── wizardLM-lora-7b +``` + ++ A workspace for `vllm`, and `model backbone weights`, `LoRA adapter weights` is strongly recommended. ++ You should expand the storage of these weight files to ensure they are logically organized in the workspace. + +## Step 4: Prepare `model repository` for Triton Server + +__4.1 Download the model repository files__ + +To use Triton, a model repository is needed, for *model path* , *backend configuration* and other information. The vllm backend is implemented based on python backend, and `sampling_params` of vllm are sampled from `model.json`. 
+ +To create a triton model repository, you may download the files through these commands: + +```bash +# NOTICE: you must first cd to your vllm_workspace path. +cd /vllm_workspace + +mkdir -p model_repository/vllm_model/1 +wget -P model_repository/vllm_model/1 https://raw.githubusercontent.com/triton-inference-server/vllm_backend/r/samples/model_repository/vllm_model/1/model.json +wget -P model_repository/vllm_model/ https://raw.githubusercontent.com/triton-inference-server/vllm_backend/r/samples/model_repository/vllm_model/config.pbtxt +``` + +The model repository should look like this: + +``` +model_repository/ +└── vllm_model + ├── 1 + │ └── model.json + └── config.pbtxt +``` + +--- + +Now, you have finished the basic deployment, and the file structure should look like this: + +``` +vllm_workspace +├── weights +│ ├── backbone +│ │ └── llama-7b-hf +│ └── loras +│ ├── alpaca-lora-7b +│ └── bactrian-x-llama-lora-7b +│ +└── model_repository + └── vllm_model + ├── 1 + │ └── model.json + └── config.pbtxt +``` + +__4.2 Populate `model.json`__ + +For this tutorial we will use the following set of parameters, specified in the `model.json`. + +```json +{ + "model":"/vllm_workspace/weights/backbone/llama-7b-hf", + "disable_log_requests": "true", + "gpu_memory_utilization": 0.8, + "tensor_parallel_size": 2, + "block_size": 16, + "enforce_eager": "true", + "enable_lora": "true", + "max_lora_rank": 16 +} +``` + ++ `model`: The path to your model repository ++ `disable_log_requests`: Whether to show logs when launching vLLM or not. ++ `gpu_memory_utilization`: The gpu memory allocated for the model weights and vllm *PagedAttention* kv cache manager. ++ `tensor_parallel_size`: vLLM now supports tensor parallelism, so you can decide how many gpus you want to use for serving. ++ `block_size`: vLLM kv cache block size. ++ `enable_lora`: If you want to support vllm multi-lora, this should be configured and set `true`. ++ `max_lora_rank`: The maximum LoRA rank of your lora adapters.
+ +The full set of parameters can be found [here](https://github.com/Yard1/vllm/blob/multi_lora/vllm/engine/arg_utils.py#L11). + +__4.3 Specify local lora path__ + +vLLM v0.4.0.post1 supports inference with **locally applied LoRA weights**, which means that vLLM cannot pull any lora adapter from huggingface. So triton should know where the local lora weights are. + +Create a `multi_lora.json` file under `model_repository/vllm_model/1/` path: + +```bash +cd model_repository/vllm_model/1 +touch multi_lora.json +``` + +The content of `multi_lora.json` should look like this: + +```json +{ + "alpaca": "/vllm_workspace/weights/loras/alpaca-lora-7b", + "bactrian": "/vllm_workspace/weights/loras/bactrian-x-llama-lora-7b" +} +``` + +The **key** should be the supported lora name, and the **value** should be the specific path in your machine. + +> **Warning**: if you set `enable_lora` to `true` in `model.json` without creating a `multi_lora.json` file, the server will throw `FileNotFoundError` when initializing. + +## Step 5: Launch Triton + +```bash +# NOTICE: you must first cd to your vllm_workspace path. +cd /vllm_workspace +tritonserver --model-store ./model_repository +``` + +After you start Triton you will see output on the console showing the server starting up and loading the model. When you see output like the following, Triton is ready to accept inference requests.
+ +``` +I1030 22:33:28.291908 1 grpc_server.cc:2513] Started GRPCInferenceService at 0.0.0.0:8001 +I1030 22:33:28.292879 1 http_server.cc:4497] Started HTTPService at 0.0.0.0:8000 +I1030 22:33:28.335154 1 http_server.cc:270] Started Metrics Service at 0.0.0.0:8002 +``` + +## Step 6: Send a request + +A client request script for multi-lora was prepared, downloading the client script from source: + +```bash +wget https://raw.githubusercontent.com/triton-inference-server/vllm_backend/main/samples/client.py +wget https://raw.githubusercontent.com/triton-inference-server/vllm_backend/main/samples/prompts.txt +``` + +Try running this script by the following command: + +```bash +python3 client.py -l +``` + +Here we assume you have prepared alpaca lora weight, thus we use: + +```bash +python3 client.py -l alpaca +``` \ No newline at end of file diff --git a/samples/client.py b/samples/client.py index 354aa36e..390a3657 100755 --- a/samples/client.py +++ b/samples/client.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 -# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions @@ -109,6 +109,8 @@ async def run(self): "max_tokens": "100", } exclude_input_in_output = self._flags.exclude_inputs_in_outputs + if self._flags.lora_name is not None: + sampling_parameters["lora_name"] = self._flags.lora_name with open(self._flags.input_prompts, "r") as file: print(f"Loading inputs from `{self._flags.input_prompts}`...") prompts = file.readlines() @@ -264,6 +266,14 @@ def create_request( default=False, help="Exclude prompt from outputs", ) + parser.add_argument( + "-l", + "--lora-name", + type=str, + required=False, + default=None, + help="The querying LoRA name", + ) FLAGS = parser.parse_args() client = LLMClient(FLAGS) diff --git a/src/model.py b/src/model.py index 1dcdecee..2d9d8ff8 100644 --- a/src/model.py +++ b/src/model.py @@ -1,4 +1,4 @@ -# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions @@ -28,16 +28,18 @@ import json import os import threading -from typing import AsyncGenerator +from typing import Dict, List import numpy as np import triton_python_backend_utils as pb_utils -from vllm import SamplingParams from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.lora.request import LoRARequest +from vllm.sampling_params import SamplingParams from vllm.utils import random_uuid _VLLM_ENGINE_ARGS_FILENAME = "model.json" +_MULTI_LORA_ARGS_FILENAME = "multi_lora.json" class TritonPythonModel: @@ -122,6 +124,27 @@ def initialize(self, args): self.llm_engine = AsyncLLMEngine.from_engine_args( AsyncEngineArgs(**vllm_engine_config) ) + self.enable_lora = False + + if ( + "enable_lora" in vllm_engine_config.keys() + and vllm_engine_config["enable_lora"].lower() == "true" + ): + # create Triton LoRA weights repository + multi_lora_args_filepath = os.path.join( + pb_utils.get_model_dir(), _MULTI_LORA_ARGS_FILENAME + ) + try: + with open(multi_lora_args_filepath) as lora_file: + lora_repository: Dict[str, str] = json.load(lora_file) + self.lora_repository = lora_repository + self.supported_loras: List[str] = list(self.lora_repository.keys()) + self.supported_loras_len = len(self.supported_loras) + self.enable_lora = True + except FileNotFoundError: + raise FileNotFoundError( + f"Triton backend cannot find {multi_lora_args_filepath}." 
+ ) output_config = pb_utils.get_output_config_by_name( self.model_config, "text_output" @@ -296,12 +319,19 @@ async def generate(self, request): parameters = request.parameters() sampling_params_dict = self.get_sampling_params_dict(parameters) + lora_name = sampling_params_dict.pop("lora_name", None) sampling_params = SamplingParams(**sampling_params_dict) - last_output = None prev_outputs = None + lora_request = None + if lora_name is not None: + lora_id = str(self.supported_loras.index(lora_name) + 1) + lora_int_id = int(lora_id) + lora_local_path = self.lora_repository[lora_name] + lora_request = LoRARequest(lora_id, lora_int_id, lora_local_path) + async for output in self.llm_engine.generate( - prompt, sampling_params, request_id + prompt, sampling_params, request_id, lora_request=lora_request ): if response_sender.is_cancelled(): self.logger.log_info("[vllm] Cancelling the request") @@ -350,6 +380,49 @@ async def generate(self, request): finally: self.ongoing_request_count -= 1 + def verify_loras(self, request): + # We will check if the requested lora exists here, if not we will send a + # response with `LoRA not found` information. In this way we may avoid + # further processing. + verified_request = None + lora_error = None + lora_name = None + parameters_input_tensor = pb_utils.get_input_tensor_by_name( + request, "sampling_parameters" + ) + if parameters_input_tensor: + parameters = parameters_input_tensor.as_numpy()[0].decode("utf-8") + sampling_params_dict = self.get_sampling_params_dict(parameters) + lora_name = sampling_params_dict.pop("lora_name", None) + + if lora_name is not None: + if not self.enable_lora: + lora_error = pb_utils.TritonError("LoRA feature is not enabled.") + self.logger.log_info( + "[vllm] LoRA is not enabled, please restart the backend with LoRA enabled." 
+ ) + elif lora_name not in self.supported_loras: + lora_error = pb_utils.TritonError( + f"LoRA {lora_name} is not supported, we currently support {self.supported_loras}" + ) + self.logger.log_info(f"[vllm] LoRA {lora_name} not found.") + + if lora_error is not None: + output_tensor = pb_utils.Tensor( + "text_output", + np.asarray(["[Error] Unsupported LoRA."], dtype=self.output_dtype), + ) + response = pb_utils.InferenceResponse( + output_tensors=[output_tensor], error=lora_error + ) + response_sender = request.get_response_sender() + response_sender.send( + response, flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL + ) + else: + verified_request = request + return verified_request + def execute(self, requests): """ Triton core issues requests to the backend via this method. @@ -361,7 +434,9 @@ def execute(self, requests): We are pushing all the requests on vllm and let it handle the full traffic. """ for request in requests: - self.create_task(self.generate(request)) + request = self.verify_loras(request) + if request is not None: + self.create_task(self.generate(request)) return None def finalize(self):