@@ -1,4 +1,4 @@
-from chat_module import LLMChatModule, supported_models, quantization_keys
+from .chat_module import ChatModule, supported_models, quantization_keys
 
 from pydantic import BaseModel
 from fastapi import FastAPI, HTTPException
@@ -15,20 +15,14 @@
 
 session = {}
 
+
 @asynccontextmanager
 async def lifespan(app: FastAPI):
 
     ARGS = _parse_args()
 
-    chat_mod = LLMChatModule(
-        ARGS.mlc_lib_path,
-        ARGS.device_name,
-        ARGS.device_id
-    )
-    model_path = os.path.join(
-        ARGS.artifact_path,
-        ARGS.model + "-" + ARGS.quantization
-    )
+    chat_mod = ChatModule(ARGS.mlc_lib_path, ARGS.device_name, ARGS.device_id)
+    model_path = os.path.join(ARGS.artifact_path, ARGS.model + "-" + ARGS.quantization)
     model_dir = ARGS.model + "-" + ARGS.quantization
     model_lib = model_dir + "-" + ARGS.device_name + ".so"
     lib_dir = os.path.join(model_path, model_lib)
@@ -38,16 +32,22 @@ async def lifespan(app: FastAPI):
     elif os.path.exists(prebuilt_lib_dir):
         lib = tvm.runtime.load_module(prebuilt_lib_dir)
     else:
-        raise ValueError(f"Unable to find {model_lib} at {lib_dir} or {prebuilt_lib_dir}.")
+        raise ValueError(
+            f"Unable to find {model_lib} at {lib_dir} or {prebuilt_lib_dir}."
+        )
 
     local_model_path = os.path.join(model_path, "params")
-    prebuilt_model_path = os.path.join(ARGS.artifact_path, "prebuilt", f"mlc-chat-{model_dir}")
+    prebuilt_model_path = os.path.join(
+        ARGS.artifact_path, "prebuilt", f"mlc-chat-{model_dir}"
+    )
     if os.path.exists(local_model_path):
         chat_mod.reload(lib=lib, model_path=local_model_path)
     elif os.path.exists(prebuilt_model_path):
         chat_mod.reload(lib=lib, model_path=prebuilt_model_path)
     else:
-        raise ValueError(f"Unable to find model params at {local_model_path} or {prebuilt_model_path}.")
+        raise ValueError(
+            f"Unable to find model params at {local_model_path} or {prebuilt_model_path}."
+        )
     session["chat_mod"] = chat_mod
 
     yield
@@ -57,13 +57,11 @@ async def lifespan(app: FastAPI):
 
 app = FastAPI(lifespan=lifespan)
 
+
 def _parse_args():
     args = argparse.ArgumentParser()
     args.add_argument(
-        "--model",
-        type=str,
-        choices=supported_models(),
-        default="vicuna-v1-7b"
+        "--model", type=str, choices=supported_models(), default="vicuna-v1-7b"
     )
     args.add_argument("--artifact-path", type=str, default="dist")
     args.add_argument(
@@ -85,65 +83,74 @@ def _parse_args():
 """
 Lists the currently supported models and provides basic information about each of them.
 """
+
+
 @app.get("/models")
 async def read_models():
-    return {
-        "data": [{
-            "id": model,
-            "object":"model"
-        } for model in supported_models()]
-    }
+    return {"data": [{"id": model, "object": "model"} for model in supported_models()]}
+
 
 """
 Retrieve a model instance with basic information about the model.
 """
+
+
 @app.get("/models/{model}")
 async def read_model(model: str):
     if model not in supported_models():
         raise HTTPException(status_code=404, detail=f"Model {model} is not supported.")
-    return {
-        "id": model,
-        "object":"model"
-    }
+    return {"id": model, "object": "model"}
+
 
 class ChatRequest(BaseModel):
     prompt: str
     stream: bool = False
 
+
 """
 Creates a model response for the given chat conversation.
 """
+
+
 @app.post("/chat/completions")
 def request_completion(request: ChatRequest):
     session["chat_mod"].prefill(input=request.prompt)
     if request.stream:
+
         def iter_response():
             while not session["chat_mod"].stopped():
                 session["chat_mod"].decode()
                 msg = session["chat_mod"].get_message()
                 yield json.dumps({"message": msg})
-        return StreamingResponse(iter_response(), media_type='application/json')
+
+        return StreamingResponse(iter_response(), media_type="application/json")
     else:
         msg = None
         while not session["chat_mod"].stopped():
             session["chat_mod"].decode()
             msg = session["chat_mod"].get_message()
         return {"message": msg}
 
+
 """
 Reset the chat for the currently initialized model.
 """
+
+
 @app.post("/chat/reset")
 def reset():
     session["chat_mod"].reset_chat()
 
+
 """
 Get the runtime stats.
 """
+
+
 @app.get("/stats")
 def read_stats():
     return session["chat_mod"].runtime_stats_text()
 
 
 if __name__ == "__main__":
-    uvicorn.run("server:app", port=8000, reload=True, access_log=False)
+    uvicorn.run("mlc_chat.server:app", port=8000, reload=True, access_log=False)
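
For reference, a minimal client sketch exercising the endpoints added in this diff. It assumes the server has already been started locally and is listening on port 8000 (matching the uvicorn.run call above) and that the third-party requests package is installed; the base URL and the example prompts are assumptions, while the paths, payload fields, and response shapes follow the handler code shown here.

# Hypothetical client sketch; assumes the server runs locally on port 8000.
import requests  # third-party HTTP client (assumption: installed separately)

BASE = "http://127.0.0.1:8000"  # assumed base URL for a local uvicorn instance

# List the supported models, then fetch the default one.
print(requests.get(f"{BASE}/models").json())
print(requests.get(f"{BASE}/models/vicuna-v1-7b").json())

# Non-streaming completion: the server replies with {"message": ...} once decoding stops.
resp = requests.post(
    f"{BASE}/chat/completions",
    json={"prompt": "What is the capital of Canada?", "stream": False},
)
print(resp.json()["message"])

# Streaming completion: each chunk is a json.dumps({"message": ...}) snapshot of the reply.
with requests.post(
    f"{BASE}/chat/completions",
    json={"prompt": "Write a haiku about GPUs.", "stream": True},
    stream=True,
) as r:
    for chunk in r.iter_content(chunk_size=None, decode_unicode=True):
        if chunk:
            print(chunk)

# Runtime stats, then reset the conversation state.
print(requests.get(f"{BASE}/stats").json())
requests.post(f"{BASE}/chat/reset")

Because the server keeps a single ChatModule in the module-level session dict, all clients share one conversation state until /chat/reset is called.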