
Commit 674da1d

chat/completions endpoint (#121)

* Initial implementation of chat/completions endpoint and its streaming variant
* Reusing datatypes from the openai entrypoints
* Response role from arg
* Added models endpoint and model validation from the request

1 parent dd1a208 commit 674da1d

3 files changed (+215, -181 lines changed)
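
For orientation (not part of the diff): a minimal client-side sketch of the two request modes the new endpoint supports. It assumes the server is running on localhost:8000 and was launched with facebook/opt-125m; both the address and the model name are assumptions, and the model name must match the served model exactly because of the validation added below.

# Sketch only: server address and model name are assumptions.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# Non-streaming request against the new /v1/chat/completions route.
chat = client.chat.completions.create(
    model="facebook/opt-125m",
    messages=[{"role": "user", "content": "Say hello."}],
)
print(chat.choices[0].message.content)

# Streaming variant: the first delta carries only the response role,
# later deltas carry content tokens.
stream = client.chat.completions.create(
    model="facebook/opt-125m",
    messages=[{"role": "user", "content": "Say hello."}],
    stream=True,
)
for chunk in stream:
    print(chunk.choices[0].delta.content or "", end="")
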

vllm/entrypoints/fast_sync_llm.py

Lines changed: 1 addition & 0 deletions

@@ -36,6 +36,7 @@ def __init__(
         self.result_queue = result_queue
         self.finish = False
         self.need_restart = False
+        self.llm_engine: LLMEngine

     def _add_request(
         self,

vllm/entrypoints/sync_openai/api_server.py

Lines changed: 214 additions & 11 deletions
@@ -4,24 +4,33 @@
 import threading
 import time
 from contextlib import asynccontextmanager
-from typing import Dict
+from http import HTTPStatus
+from typing import Dict, Iterable, List, Union, cast

 import uvicorn
 from fastapi import FastAPI, Request
 from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import StreamingResponse
+from fastapi.responses import JSONResponse, StreamingResponse
 from fastapi.routing import Mount
+from openai.types.chat import ChatCompletionContentPartTextParam
 from prometheus_client import make_asgi_app

+import vllm
 from vllm import FastSyncLLM as LLM
 from vllm import envs
 from vllm.engine.arg_utils import EngineArgs
 from vllm.entrypoints.openai.cli_args import make_arg_parser
-from vllm.entrypoints.sync_openai.protocol import (CompletionRequest,
-                                                   CompletionResponse,
-                                                   CompletionResponseChoice,
-                                                   UsageInfo)
+from vllm.entrypoints.openai.protocol import (
+    ChatCompletionContentPartParam, ChatCompletionMessageParam,
+    ChatCompletionRequest, ChatCompletionResponse,
+    ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
+    ChatCompletionStreamResponse, ChatMessage, CompletionRequest,
+    CompletionResponse, CompletionResponseChoice, DeltaMessage, ErrorResponse,
+    ModelCard, ModelList, ModelPermission, UsageInfo)
+from vllm.entrypoints.openai.serving_chat import (ChatMessageParseResult,
+                                                  ConversationMessage)
 from vllm.logger import init_logger
+from vllm.transformers_utils.tokenizer import get_tokenizer
 from vllm.utils import random_uuid

 mp = multiprocessing.get_context(envs.VLLM_WORKER_MULTIPROC_METHOD)
@@ -41,14 +50,19 @@ class BackgroundRunner:

     def __init__(self):
         self.value = 0
-        self.engine_args = None
+        self.engine_args: EngineArgs
         self.input_queue: multiprocessing.Queue = mp.Queue()
         self.result_queue: multiprocessing.Queue = mp.Queue()
         self.result_queues: Dict[str, asyncio.Queue] = {}
         self.t: threading.Thread = threading.Thread(target=self.thread_proc)
         self.loop = None
         self.llm: LLM
         self.proc: multiprocessing.Process
+        self.tokenizer = None
+        self.response_role: str
+
+    def set_response_role(self, role):
+        self.response_role = role

     def set_engine_args(self, engine_args):
         self.engine_args = engine_args
@@ -75,6 +89,7 @@ async def run_main(self):
             input_queue=self.input_queue,
             result_queue=self.result_queue,
         )
+
         self.loop = asyncio.get_event_loop()
         self.proc = mp.Process(target=self.llm.run_engine)
         self.t.start()
@@ -103,6 +118,15 @@ async def lifespan(app: FastAPI):
     asyncio.create_task(runner.run_main())
     await runner.result_queues["Ready"].get()
     del runner.result_queues["Ready"]
+
+    tokenizer = get_tokenizer(
+        engine_args.tokenizer,
+        tokenizer_mode=engine_args.tokenizer_mode,
+        tokenizer_revision=engine_args.tokenizer_revision,
+        trust_remote_code=engine_args.trust_remote_code,
+        truncation_side="left")
+    runner.tokenizer = tokenizer
+
     yield

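The tokenizer loaded here is what the chat endpoint below uses to render the parsed conversation into a single prompt string via apply_chat_template. A standalone sketch of that step, with an assumed model name (any tokenizer that ships a chat template behaves the same way):

# Illustrative sketch, not part of the diff; model name is an assumption.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
conversation = [{"role": "user", "content": "Say hello."}]
prompt = tok.apply_chat_template(conversation,
                                 tokenize=False,
                                 add_generation_prompt=True)
print(prompt)  # chat-formatted string that is then submitted to the engine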

@@ -115,6 +139,33 @@ async def lifespan(app: FastAPI):
 app.routes.append(route)


+@app.get("/v1/models")
+async def show_available_models():
+    models = [
+        ModelCard(id=runner.engine_args.model,
+                  root=runner.engine_args.model,
+                  permission=[ModelPermission()])
+    ]
+    model_list = ModelList(data=models)
+    return JSONResponse(content=model_list.model_dump())
+
+
+@app.get("/version")
+async def show_version():
+    ver = {"version": vllm.__version__}
+    return JSONResponse(content=ver)
+
+
+async def _check_model(request: Union[CompletionRequest,
+                                      ChatCompletionRequest]):
+    model = request.model
+    if model != runner.engine_args.model:
+        return ErrorResponse(message=f"The model {model} does not exist.",
+                             type="NotFoundError",
+                             code=HTTPStatus.NOT_FOUND)
+    return None
+
+
 async def completion_generator(model, result_queue, choices, created_time,
                                ids):
     completed = 0
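
A rough sketch of how the three additions above behave from a client's point of view, assuming the server listens on localhost:8000 and serves facebook/opt-125m (both assumptions):

import requests

# /v1/models lists the single served model as a ModelCard.
print(requests.get("http://localhost:8000/v1/models").json())

# /version reports the installed vllm package version.
print(requests.get("http://localhost:8000/version").json())

# _check_model rejects any other model name with a 404 ErrorResponse
# before the prompt ever reaches the engine.
r = requests.post("http://localhost:8000/v1/completions",
                  json={"model": "some-other-model", "prompt": "hi"})
print(r.status_code, r.json().get("message"))
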
@@ -139,8 +190,9 @@ async def completion_generator(model, result_queue, choices, created_time,
                 res.usage = UsageInfo()
                 res.usage.completion_tokens = stats.get("tokens", 0)
                 res.usage.prompt_tokens = stats.get("prompt", 0)
-                res.usage.total_tokens = (res.usage.completion_tokens +
-                                          res.usage.prompt_tokens)
+                res.usage.total_tokens = (
+                    res.usage.completion_tokens +  # type: ignore
+                    res.usage.prompt_tokens)
                 res.choices[0].finish_reason = stats["finish_reason"]
                 res.choices[0].stop_reason = stats["stop_reason"]
                 completed += 1
@@ -158,6 +210,10 @@ async def completion_generator(model, result_queue, choices, created_time,

 @app.post("/v1/completions")
 async def completions(request: CompletionRequest, raw_request: Request):
+    error_check_ret = await _check_model(request)
+    if error_check_ret is not None:
+        return JSONResponse(content=error_check_ret.model_dump(),
+                            status_code=error_check_ret.code)
     sampling_params = request.to_sampling_params()
     ids, result_queue = await runner.add_request(request.prompt,
                                                  sampling_params)
@@ -179,8 +235,7 @@ async def completions(request: CompletionRequest, raw_request: Request):
         created_time = int(time.time())
         return StreamingResponse(content=completion_generator(
             request.model, result_queue, choices, created_time, ids),
-                                 media_type="text/event-stream",
-                                 headers={"Access-Control-Allow-Origin": "*"})
+                                 media_type="text/event-stream")
     while True:
         request_id, token, stats = await result_queue.get()
         choice_idx = choices[request_id]
@@ -200,6 +255,153 @@ async def completions(request: CompletionRequest, raw_request: Request):
     return res


+def parse_chat_message_content_parts(
+    role: str,
+    parts: Iterable[ChatCompletionContentPartParam],
+) -> ChatMessageParseResult:
+    texts: List[str] = []
+
+    for _, part in enumerate(parts):
+        part_type = part["type"]
+        if part_type == "text":
+            text = cast(ChatCompletionContentPartTextParam, part)["text"]
+
+            texts.append(text)
+        else:
+            raise NotImplementedError(f"Unknown part type: {part_type}")
+
+    messages = [ConversationMessage(role=role, content="\n".join(texts))]
+
+    return ChatMessageParseResult(messages=messages)
+
+
+def parse_chat_message_content(
+        message: ChatCompletionMessageParam, ) -> ChatMessageParseResult:
+    role = message["role"]
+    content = message.get("content")
+
+    if content is None:
+        return ChatMessageParseResult(messages=[])
+    if isinstance(content, str):
+        messages = [ConversationMessage(role=role, content=content)]
+        return ChatMessageParseResult(messages=messages)
+
+    return parse_chat_message_content_parts(role, content)
+
+
+async def chat_completion_generator(model, result_queue, created_time, id):
+    try:
+        first_token = ChatCompletionStreamResponse(
+            id=id,
+            created=created_time,
+            model=model,
+            choices=[
+                ChatCompletionResponseStreamChoice(
+                    index=0,
+                    delta=DeltaMessage(role=runner.response_role),
+                    logprobs=None,
+                    finish_reason=None,
+                    stop_reason=None)
+            ],
+            usage=None)
+        response_json = first_token.model_dump_json(exclude_unset=True)
+        yield f"data: {response_json}\n\n"
+
+        while True:
+            request_id, token, stats = await result_queue.get()
+            assert request_id == id
+
+            res = ChatCompletionStreamResponse(
+                id=request_id,
+                created=created_time,
+                model=model,
+                choices=[
+                    ChatCompletionResponseStreamChoice(
+                        index=0,
+                        delta=DeltaMessage(content=token),
+                        logprobs=None,
+                        finish_reason=None,
+                        stop_reason=None)
+                ],
+                usage=None)
+            if stats is not None:
+                res.usage = UsageInfo()
+                res.usage.completion_tokens = stats.get("tokens", 0)
+                res.usage.prompt_tokens = stats.get("prompt", 0)
+                res.usage.total_tokens = (
+                    res.usage.completion_tokens +  # type: ignore
+                    res.usage.prompt_tokens)
+                res.choices[0].finish_reason = stats["finish_reason"]
+                res.choices[0].stop_reason = stats["stop_reason"]
+            response_json = res.model_dump_json(exclude_unset=True)
+            yield f"data: {response_json}\n\n"
+            if stats is not None:
+                runner.remove_result_queues([id])
+                break
+
+        yield "data: [DONE]\n\n"
+    except Exception as e:
+        logger.error("Error in completion_generator: %s", e)
+    return
+
+
+@app.post("/v1/chat/completions")
+async def chat_completions(request: ChatCompletionRequest,
+                           raw_request: Request):
+    error_check_ret = await _check_model(request)
+    if error_check_ret is not None:
+        return JSONResponse(content=error_check_ret.model_dump(),
+                            status_code=error_check_ret.code)
+    sampling_params = request.to_sampling_params()
+    conversation: List[ConversationMessage] = []
+
+    res = ChatCompletionResponse(model=request.model,
+                                 choices=[],
+                                 usage=UsageInfo(prompt_tokens=0,
+                                                 total_tokens=0,
+                                                 completion_tokens=0))
+
+    for msg in request.messages:
+        parsed_msg = parse_chat_message_content(msg)
+        conversation.extend(parsed_msg.messages)
+
+    prompt = runner.tokenizer.apply_chat_template(  # type: ignore
+        conversation=conversation,
+        tokenize=False,
+        add_generation_prompt=request.add_generation_prompt,
+    )
+
+    ids, result_queue = await runner.add_request(prompt, sampling_params)
+    assert len(ids) == 1
+
+    if request.stream:
+        created_time = int(time.time())
+        return StreamingResponse(content=chat_completion_generator(
+            request.model, result_queue, created_time, ids[0]),
+                                 media_type="text/event-stream")
+
+    res.choices.append(
+        ChatCompletionResponseChoice(
+            index=0,
+            message=ChatMessage(role=runner.response_role, content=""),
+            finish_reason=None,
+            stop_reason=None))
+
+    while True:
+        _, token, stats = await result_queue.get()
+        res.choices[0].message.content += str(token)
+        if stats is not None:
+            res.usage.completion_tokens += stats["tokens"]  # type: ignore
+            res.usage.prompt_tokens += stats["prompt"]  # type: ignore
+            res.choices[0].finish_reason = stats["finish_reason"]
+            res.choices[0].stop_reason = stats["stop_reason"]
+            runner.remove_result_queues(ids)
+            break
+    res.usage.total_tokens = (  # type: ignore
+        res.usage.completion_tokens + res.usage.prompt_tokens)  # type: ignore
+    return res
+
+
 def parse_args():
     parser = make_arg_parser()
     return parser.parse_args()
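
The streaming branch above emits plain server-sent events: one "data: {json}" line per chunk, closed by "data: [DONE]". A minimal sketch of consuming that stream directly, with host, port, and model name as assumptions:

import json
import requests

payload = {
    "model": "facebook/opt-125m",
    "messages": [{"role": "user", "content": "Say hello."}],
    "stream": True,
}
with requests.post("http://localhost:8000/v1/chat/completions",
                   json=payload, stream=True) as resp:
    for line in resp.iter_lines():
        if not line:
            continue
        data = line.decode().removeprefix("data: ")
        if data == "[DONE]":
            break
        chunk = json.loads(data)
        delta = chunk["choices"][0]["delta"]
        # The first chunk carries only the role; later chunks carry content
        # tokens, and the final one also includes usage statistics.
        print(delta.get("content", ""), end="")
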
@@ -209,6 +411,7 @@ def parse_args():
     args = parse_args()
     engine_args = EngineArgs.from_cli_args(args)
     runner.set_engine_args(engine_args)
+    runner.set_response_role(args.response_role)

     app.add_middleware(
         CORSMiddleware,
