 import threading
 import time
 from contextlib import asynccontextmanager
-from typing import Dict
+from http import HTTPStatus
+from typing import Dict, Iterable, List, Union, cast

 import uvicorn
 from fastapi import FastAPI, Request
 from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import StreamingResponse
+from fastapi.responses import JSONResponse, StreamingResponse
 from fastapi.routing import Mount
+from openai.types.chat import ChatCompletionContentPartTextParam
 from prometheus_client import make_asgi_app

+import vllm
 from vllm import FastSyncLLM as LLM
 from vllm import envs
 from vllm.engine.arg_utils import EngineArgs
 from vllm.entrypoints.openai.cli_args import make_arg_parser
-from vllm.entrypoints.sync_openai.protocol import (CompletionRequest,
-                                                   CompletionResponse,
-                                                   CompletionResponseChoice,
-                                                   UsageInfo)
+from vllm.entrypoints.openai.protocol import (
+    ChatCompletionContentPartParam, ChatCompletionMessageParam,
+    ChatCompletionRequest, ChatCompletionResponse,
+    ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
+    ChatCompletionStreamResponse, ChatMessage, CompletionRequest,
+    CompletionResponse, CompletionResponseChoice, DeltaMessage, ErrorResponse,
+    ModelCard, ModelList, ModelPermission, UsageInfo)
+from vllm.entrypoints.openai.serving_chat import (ChatMessageParseResult,
+                                                  ConversationMessage)
 from vllm.logger import init_logger
+from vllm.transformers_utils.tokenizer import get_tokenizer
 from vllm.utils import random_uuid

 mp = multiprocessing.get_context(envs.VLLM_WORKER_MULTIPROC_METHOD)
@@ -41,14 +50,19 @@ class BackgroundRunner:

     def __init__(self):
         self.value = 0
-        self.engine_args = None
+        self.engine_args: EngineArgs
         self.input_queue: multiprocessing.Queue = mp.Queue()
         self.result_queue: multiprocessing.Queue = mp.Queue()
         self.result_queues: Dict[str, asyncio.Queue] = {}
         self.t: threading.Thread = threading.Thread(target=self.thread_proc)
         self.loop = None
         self.llm: LLM
         self.proc: multiprocessing.Process
+        self.tokenizer = None
+        self.response_role: str
+
+    def set_response_role(self, role):
+        self.response_role = role

     def set_engine_args(self, engine_args):
         self.engine_args = engine_args
@@ -75,6 +89,7 @@ async def run_main(self):
             input_queue=self.input_queue,
             result_queue=self.result_queue,
         )
+
         self.loop = asyncio.get_event_loop()
         self.proc = mp.Process(target=self.llm.run_engine)
         self.t.start()
@@ -103,6 +118,15 @@ async def lifespan(app: FastAPI):
     asyncio.create_task(runner.run_main())
     await runner.result_queues["Ready"].get()
     del runner.result_queues["Ready"]
+
+    tokenizer = get_tokenizer(
+        engine_args.tokenizer,
+        tokenizer_mode=engine_args.tokenizer_mode,
+        tokenizer_revision=engine_args.tokenizer_revision,
+        trust_remote_code=engine_args.trust_remote_code,
+        truncation_side="left")
+    runner.tokenizer = tokenizer
+
     yield


@@ -115,6 +139,33 @@ async def lifespan(app: FastAPI):
 app.routes.append(route)


+@app.get("/v1/models")
+async def show_available_models():
+    models = [
+        ModelCard(id=runner.engine_args.model,
+                  root=runner.engine_args.model,
+                  permission=[ModelPermission()])
+    ]
+    model_list = ModelList(data=models)
+    return JSONResponse(content=model_list.model_dump())
+
+
+@app.get("/version")
+async def show_version():
+    ver = {"version": vllm.__version__}
+    return JSONResponse(content=ver)
+
+
+async def _check_model(request: Union[CompletionRequest,
+                                      ChatCompletionRequest]):
+    model = request.model
+    if model != runner.engine_args.model:
+        return ErrorResponse(message=f"The model {model} does not exist.",
+                             type="NotFoundError",
+                             code=HTTPStatus.NOT_FOUND)
+    return None
+
+
 async def completion_generator(model, result_queue, choices, created_time,
                                ids):
     completed = 0
@@ -139,8 +190,9 @@ async def completion_generator(model, result_queue, choices, created_time,
             res.usage = UsageInfo()
             res.usage.completion_tokens = stats.get("tokens", 0)
             res.usage.prompt_tokens = stats.get("prompt", 0)
-            res.usage.total_tokens = (res.usage.completion_tokens +
-                                      res.usage.prompt_tokens)
+            res.usage.total_tokens = (
+                res.usage.completion_tokens +  # type: ignore
+                res.usage.prompt_tokens)
             res.choices[0].finish_reason = stats["finish_reason"]
             res.choices[0].stop_reason = stats["stop_reason"]
             completed += 1
@@ -158,6 +210,10 @@ async def completion_generator(model, result_queue, choices, created_time,

 @app.post("/v1/completions")
 async def completions(request: CompletionRequest, raw_request: Request):
+    error_check_ret = await _check_model(request)
+    if error_check_ret is not None:
+        return JSONResponse(content=error_check_ret.model_dump(),
+                            status_code=error_check_ret.code)
     sampling_params = request.to_sampling_params()
     ids, result_queue = await runner.add_request(request.prompt,
                                                  sampling_params)
@@ -179,8 +235,7 @@ async def completions(request: CompletionRequest, raw_request: Request):
         created_time = int(time.time())
         return StreamingResponse(content=completion_generator(
             request.model, result_queue, choices, created_time, ids),
-                                 media_type="text/event-stream",
-                                 headers={"Access-Control-Allow-Origin": "*"})
+                                 media_type="text/event-stream")
     while True:
         request_id, token, stats = await result_queue.get()
         choice_idx = choices[request_id]
@@ -200,6 +255,153 @@ async def completions(request: CompletionRequest, raw_request: Request):
     return res


+def parse_chat_message_content_parts(
+    role: str,
+    parts: Iterable[ChatCompletionContentPartParam],
+) -> ChatMessageParseResult:
+    texts: List[str] = []
+
+    for _, part in enumerate(parts):
+        part_type = part["type"]
+        if part_type == "text":
+            text = cast(ChatCompletionContentPartTextParam, part)["text"]
+
+            texts.append(text)
+        else:
+            raise NotImplementedError(f"Unknown part type: {part_type}")
+
+    messages = [ConversationMessage(role=role, content="\n".join(texts))]
+
+    return ChatMessageParseResult(messages=messages)
+
+
+def parse_chat_message_content(
+        message: ChatCompletionMessageParam, ) -> ChatMessageParseResult:
+    role = message["role"]
+    content = message.get("content")
+
+    if content is None:
+        return ChatMessageParseResult(messages=[])
+    if isinstance(content, str):
+        messages = [ConversationMessage(role=role, content=content)]
+        return ChatMessageParseResult(messages=messages)
+
+    return parse_chat_message_content_parts(role, content)
+
+
+async def chat_completion_generator(model, result_queue, created_time, id):
+    try:
+        first_token = ChatCompletionStreamResponse(
+            id=id,
+            created=created_time,
+            model=model,
+            choices=[
+                ChatCompletionResponseStreamChoice(
+                    index=0,
+                    delta=DeltaMessage(role=runner.response_role),
+                    logprobs=None,
+                    finish_reason=None,
+                    stop_reason=None)
+            ],
+            usage=None)
+        response_json = first_token.model_dump_json(exclude_unset=True)
+        yield f"data: {response_json}\n\n"
+
+        while True:
+            request_id, token, stats = await result_queue.get()
+            assert request_id == id
+
+            res = ChatCompletionStreamResponse(
+                id=request_id,
+                created=created_time,
+                model=model,
+                choices=[
+                    ChatCompletionResponseStreamChoice(
+                        index=0,
+                        delta=DeltaMessage(content=token),
+                        logprobs=None,
+                        finish_reason=None,
+                        stop_reason=None)
+                ],
+                usage=None)
+            if stats is not None:
+                res.usage = UsageInfo()
+                res.usage.completion_tokens = stats.get("tokens", 0)
+                res.usage.prompt_tokens = stats.get("prompt", 0)
+                res.usage.total_tokens = (
+                    res.usage.completion_tokens +  # type: ignore
+                    res.usage.prompt_tokens)
+                res.choices[0].finish_reason = stats["finish_reason"]
+                res.choices[0].stop_reason = stats["stop_reason"]
+            response_json = res.model_dump_json(exclude_unset=True)
+            yield f"data: {response_json}\n\n"
+            if stats is not None:
+                runner.remove_result_queues([id])
+                break
+
+        yield "data: [DONE]\n\n"
+    except Exception as e:
+        logger.error("Error in completion_generator: %s", e)
+        return
+
+
+@app.post("/v1/chat/completions")
+async def chat_completions(request: ChatCompletionRequest,
+                           raw_request: Request):
+    error_check_ret = await _check_model(request)
+    if error_check_ret is not None:
+        return JSONResponse(content=error_check_ret.model_dump(),
+                            status_code=error_check_ret.code)
+    sampling_params = request.to_sampling_params()
+    conversation: List[ConversationMessage] = []
+
+    res = ChatCompletionResponse(model=request.model,
+                                 choices=[],
+                                 usage=UsageInfo(prompt_tokens=0,
+                                                 total_tokens=0,
+                                                 completion_tokens=0))
+
+    for msg in request.messages:
+        parsed_msg = parse_chat_message_content(msg)
+        conversation.extend(parsed_msg.messages)
+
+    prompt = runner.tokenizer.apply_chat_template(  # type: ignore
+        conversation=conversation,
+        tokenize=False,
+        add_generation_prompt=request.add_generation_prompt,
+    )
+
+    ids, result_queue = await runner.add_request(prompt, sampling_params)
+    assert len(ids) == 1
+
+    if request.stream:
+        created_time = int(time.time())
+        return StreamingResponse(content=chat_completion_generator(
+            request.model, result_queue, created_time, ids[0]),
+                                 media_type="text/event-stream")
+
+    res.choices.append(
+        ChatCompletionResponseChoice(
+            index=0,
+            message=ChatMessage(role=runner.response_role, content=""),
+            finish_reason=None,
+            stop_reason=None))
+
+    while True:
+        _, token, stats = await result_queue.get()
+        res.choices[0].message.content += str(token)
+        if stats is not None:
+            res.usage.completion_tokens += stats["tokens"]  # type: ignore
+            res.usage.prompt_tokens += stats["prompt"]  # type: ignore
+            res.choices[0].finish_reason = stats["finish_reason"]
+            res.choices[0].stop_reason = stats["stop_reason"]
+            runner.remove_result_queues(ids)
+            break
+    res.usage.total_tokens = (  # type: ignore
+        res.usage.completion_tokens + res.usage.prompt_tokens)  # type: ignore
+    return res
+
+
 def parse_args():
     parser = make_arg_parser()
     return parser.parse_args()
@@ -209,6 +411,7 @@ def parse_args():
     args = parse_args()
     engine_args = EngineArgs.from_cli_args(args)
     runner.set_engine_args(engine_args)
+    runner.set_response_role(args.response_role)

     app.add_middleware(
         CORSMiddleware,
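
For a quick sanity check of the new chat endpoint once the server is running, a request along the following lines should work. This is a minimal sketch, not part of the change itself: the host, port, model name, and prompt are illustrative and must match how the server was actually launched, since /v1/chat/completions returns a 404 error when the "model" field differs from the served model.

# Illustrative client call; assumes the server listens on localhost:8000 and
# was started with --model facebook/opt-125m (use the model you actually serve).
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "facebook/opt-125m",
        "messages": [{"role": "user", "content": "Say hello."}],
        "max_tokens": 32,
        "stream": False,
    },
)
# The non-streaming response carries the generated text in the first choice.
print(resp.json()["choices"][0]["message"]["content"])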