Add vllm_worker support for lora_modules #3534

Status: Open · wants to merge 4 commits into base: main
docs/vllm_integration.md (66 additions, 0 deletions)
@@ -23,3 +23,69 @@ See the supported models [here](https://vllm.readthedocs.io/en/latest/models/sup
```
python3 -m fastchat.serve.vllm_worker --model-path TheBloke/vicuna-7B-v1.5-AWQ --quantization awq
```

## Serving LoRA adapters with vllm_worker

The vLLM worker can serve one or more LoRA adapters alongside the base model: register each adapter with `--lora-modules name=path`, list the adapter names in `--model-names`, and clients then select an adapter through the `model` field of their requests.

### Usage

1. Start the worker with LoRA enabled:

```bash
export VLLM_WORKER_MULTIPROC_METHOD=spawn
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 -m fastchat.serve.vllm_worker \
--model-path /data/models/Qwen/Qwen2-72B-Instruct \
--tokenizer /data/models/Qwen/Qwen2-72B-Instruct \
--enable-lora \
--lora-modules m1=/data/modules/lora/adapter/m1 m2=/data/modules/lora/adapter/m2 m3=/data/modules/lora/adapter/m3 \
    --model-names qwen2-72b-instruct,m1,m2,m3 \
--controller http://localhost:21001 \
--host 0.0.0.0 \
--num-gpus 8 \
--port 31034 \
--limit-worker-concurrency 100 \
--worker-address http://localhost:31034
```

2. Send a chat completion request through the FastChat OpenAI-compatible API server. A request whose `model` matches an adapter name uses that adapter; the base model name runs without LoRA. A Python client sketch follows the curl examples below.

- Example 1: request the LoRA adapter `m1`

```bash
curl --location --request POST 'http://fastchat_address:port/v1/chat/completions' \
--header 'Content-Type: application/json' \
--header 'Authorization: Bearer sk-xxx' \
--data-raw '{
"model": "m1",
"stream": false,
"temperature": 0.7,
"top_p": 0.1,
"max_tokens": 4096,
"messages": [
{
"role": "user",
"content": "Hi?"
}
]
}'
```

- Example 2: request the base model `qwen2-72b-instruct` (no LoRA adapter applied)

```bash
curl --location --request POST 'http://fastchat_address:port/v1/chat/completions' \
--header 'Content-Type: application/json' \
--header 'Authorization: Bearer sk-xxx' \
--data-raw '{
"model": "qwen2-72b-instruct",
"stream": false,
"temperature": 0.7,
"top_p": 0.1,
"max_tokens": 4096,
"messages": [
{
"role": "user",
"content": "Hi?"
}
]
}'
```
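
The same requests can be sent from Python. The following is a minimal sketch assuming the `openai` package (v1+) is installed and that `http://fastchat_address:port` and `sk-xxx` are the same placeholder address and API key used in the curl examples above.

```python
from openai import OpenAI

# Point the client at the FastChat OpenAI-compatible API server
# (placeholder base URL and key, matching the curl examples above).
client = OpenAI(base_url="http://fastchat_address:port/v1", api_key="sk-xxx")

# "m1" selects the LoRA adapter registered under that name;
# use "qwen2-72b-instruct" instead to query the base model.
response = client.chat.completions.create(
    model="m1",
    temperature=0.7,
    top_p=0.1,
    max_tokens=4096,
    messages=[{"role": "user", "content": "Hi?"}],
)
print(response.choices[0].message.content)
```

To check which model names the API server currently exposes (the base model plus every adapter listed in `--model-names`), query `curl http://fastchat_address:port/v1/models`.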
fastchat/serve/vllm_worker.py (52 additions, 6 deletions)
@@ -13,7 +13,9 @@
from fastapi.responses import StreamingResponse, JSONResponse
import uvicorn
from vllm import AsyncLLMEngine
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.entrypoints.openai.cli_args import LoRAParserAction
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid

@@ -24,7 +26,6 @@
)
from fastchat.utils import get_context_length, is_partial_stop


app = FastAPI()


@@ -40,6 +41,7 @@ def __init__(
no_register: bool,
llm_engine: AsyncLLMEngine,
conv_template: str,
        lora_requests: List[LoRARequest],
):
super().__init__(
controller_addr,
@@ -55,6 +57,7 @@ def __init__(
f"Loading the model {self.model_names} on worker {worker_id}, worker type: vLLM worker..."
)
self.tokenizer = llm_engine.engine.tokenizer
self.lora_requests = lora_requests
# This is to support vllm >= 0.2.7 where TokenizerGroup was introduced
# and llm_engine.engine.tokenizer was no longer a raw tokenizer
if hasattr(self.tokenizer, "tokenizer"):
@@ -64,9 +67,24 @@ def __init__(
if not no_register:
self.init_heart_beat()

def find_lora(self, model):
lora_request = next(
            (item for item in self.lora_requests if item.lora_name == model), None
)

if lora_request:
logger.info(f"Successfully selected LoRA adapter: {model}")
return lora_request
else:
logger.warning(
f"Corresponding LoRA not found: {model}, will perform inference without LoRA adapter."
)
return None

async def generate_stream(self, params):
self.call_ct += 1

model = params.pop("model")
context = params.pop("prompt")
request_id = params.pop("request_id")
temperature = float(params.get("temperature", 1.0))
@@ -116,7 +134,12 @@ async def generate_stream(self, params):
frequency_penalty=frequency_penalty,
best_of=best_of,
)
results_generator = engine.generate(context, sampling_params, request_id)
lora_request = None
if self.lora_requests and len(self.lora_requests) > 0:
lora_request = self.find_lora(model)
results_generator = engine.generate(
context, sampling_params, request_id, lora_request=lora_request
)

async for request_output in results_generator:
prompt = request_output.prompt
@@ -156,9 +179,11 @@
"cumulative_logprob": [
output.cumulative_logprob for output in request_output.outputs
],
"finish_reason": request_output.outputs[0].finish_reason
if len(request_output.outputs) == 1
else [output.finish_reason for output in request_output.outputs],
"finish_reason": (
request_output.outputs[0].finish_reason
if len(request_output.outputs) == 1
else [output.finish_reason for output in request_output.outputs]
),
}
# Emit twice here to ensure a 'finish_reason' with empty content in the OpenAI API response.
# This aligns with the behavior of model_worker.
@@ -278,6 +303,15 @@ async def api_model_details(request: Request):
"throughput. However, if the value is too high, it may cause out-of-"
"memory (OOM) errors.",
)
parser.add_argument(
"--lora-modules",
type=nullable_str,
default=None,
nargs="+",
action=LoRAParserAction,
help="LoRA module configurations in the format name=path. "
"Multiple modules can be specified.",
)

parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
@@ -286,6 +320,17 @@ async def api_model_details(request: Request):
if args.num_gpus > 1:
args.tensor_parallel_size = args.num_gpus

lora_requests = None
if args.lora_modules is not None:
lora_requests = [
LoRARequest(
lora_name=lora.name,
lora_int_id=i,
lora_path=lora.path,
)
for i, lora in enumerate(args.lora_modules, start=1)
]

engine_args = AsyncEngineArgs.from_cli_args(args)
engine = AsyncLLMEngine.from_engine_args(engine_args)
worker = VLLMWorker(
@@ -298,5 +343,6 @@ async def api_model_details(request: Request):
args.no_register,
engine,
args.conv_template,
lora_requests,
)
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
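
For reference, the adapter bookkeeping added above can be summarized in a small standalone sketch: one request object per `name=path` pair from `--lora-modules`, looked up by the `model` field of an incoming request. It uses a plain stand-in class rather than `vllm.lora.request.LoRARequest`, so the class and helper names here are illustrative only.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class LoRARequestStub:
    """Stand-in for vllm.lora.request.LoRARequest, for illustration only."""
    lora_name: str
    lora_int_id: int  # ids start at 1, matching enumerate(..., start=1) in the PR
    lora_path: str


def build_lora_requests(modules: List[str]) -> List[LoRARequestStub]:
    """Turn 'name=path' strings (the --lora-modules format) into request objects."""
    return [
        LoRARequestStub(lora_name=name, lora_int_id=i, lora_path=path)
        for i, (name, path) in enumerate((m.split("=", 1) for m in modules), start=1)
    ]


def find_lora(lora_requests: List[LoRARequestStub], model: str) -> Optional[LoRARequestStub]:
    """Return the adapter whose name matches `model`, or None to fall back to the base model."""
    return next((r for r in lora_requests if r.lora_name == model), None)


if __name__ == "__main__":
    requests = build_lora_requests(["m1=/data/modules/lora/adapter/m1"])
    print(find_lora(requests, "m1"))                  # the m1 adapter is selected
    print(find_lora(requests, "qwen2-72b-instruct"))  # None: base model, no adapter
```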