from dataclasses import asdict
from typing import Dict, List, Union

- from api.api import AssistantMessage, CompletionRequest, OpenAiApiGenerator, UserMessage
+ from api.api import CompletionRequest, OpenAiApiGenerator
+ from api.models import get_model_info_list, retrieve_model_info

from build.builder import BuilderArgs, TokenizerArgs
from flask import Flask, request, Response
from generate import GeneratorArgs


- """
- Creates a flask app that can be used to serve the model as a chat API.
- """
- app = Flask(__name__)
- # Messages and gen are kept global so they can be accessed by the flask app endpoints.
- messages: list = []
- gen: OpenAiApiGenerator = None
+ def create_app(args):
+     """
+     Creates a flask app that can be used to serve the model as a chat API.
+     """
+     app = Flask(__name__)

+     gen: OpenAiApiGenerator = initialize_generator(args)

- def _del_none(d: Union[Dict, List]) -> Union[Dict, List]:
-     """Recursively delete None values from a dictionary."""
-     if type(d) is dict:
-         return {k: _del_none(v) for k, v in d.items() if v}
-     elif type(d) is list:
-         return [_del_none(v) for v in d if v]
-     return d
+     def _del_none(d: Union[Dict, List]) -> Union[Dict, List]:
+         """Recursively delete None values from a dictionary."""
+         if type(d) is dict:
+             return {k: _del_none(v) for k, v in d.items() if v}
+         elif type(d) is list:
+             return [_del_none(v) for v in d if v]
+         return d

- @app.route("/chat", methods=["POST"])
- def chat_endpoint():
-     """
-     Endpoint for the Chat API. This endpoint is used to generate a response to a user prompt.
-     This endpoint emulates the behavior of the OpenAI Chat API. (https://platform.openai.com/docs/api-reference/chat)
+     @app.route("/chat", methods=["POST"])
+     def chat_endpoint():
+         """
+         Endpoint for the Chat API. This endpoint is used to generate a response to a user prompt.
+         This endpoint emulates the behavior of the OpenAI Chat API. (https://platform.openai.com/docs/api-reference/chat)

-     **Warning**: Not all arguments of the CompletionRequest are consumed.
+         **Warning**: Not all arguments of the CompletionRequest are consumed.

-     See https://github.com/pytorch/torchchat/issues/973 and the OpenAiApiGenerator class for more details.
+         See https://github.com/pytorch/torchchat/issues/973 and the OpenAiApiGenerator class for more details.

-     If stream is set to true, the response will be streamed back as a series of CompletionResponseChunk objects. Otherwise,
-     a single CompletionResponse object will be returned.
-     """
+         If stream is set to true, the response will be streamed back as a series of CompletionResponseChunk objects. Otherwise,
+         a single CompletionResponse object will be returned.
+         """

-     print(" === Completion Request ===")
+         print(" === Completion Request ===")

-     # Parse the request into a CompletionRequest object
-     data = request.get_json()
-     req = CompletionRequest(**data)
+         # Parse the request into a CompletionRequest object
+         data = request.get_json()
+         req = CompletionRequest(**data)

-     # Add the user message to our internal message history.
-     messages.append(UserMessage(**req.messages[-1]))
-
-     if data.get("stream") == "true":
+         if data.get("stream") == "true":

-         def chunk_processor(chunked_completion_generator):
-             """Inline function for postprocessing CompletionResponseChunk objects.
-
-             Here, we just jsonify the chunk and yield it as a string.
-             """
-             messages.append(AssistantMessage(content=""))
-             for chunk in chunked_completion_generator:
-                 if (next_tok := chunk.choices[0].delta.content) is None:
-                     next_tok = ""
-                 messages[-1].content += next_tok
-                 print(next_tok, end="")
-                 yield json.dumps(_del_none(asdict(chunk)))
+             def chunk_processor(chunked_completion_generator):
+                 """Inline function for postprocessing CompletionResponseChunk objects.
+
+                 Here, we just jsonify the chunk and yield it as a string.
+                 """
+                 for chunk in chunked_completion_generator:
+                     if (next_tok := chunk.choices[0].delta.content) is None:
+                         next_tok = ""
+                     print(next_tok, end="")
+                     yield json.dumps(_del_none(asdict(chunk)))

-         return Response(
-             chunk_processor(gen.chunked_completion(req)), mimetype="text/event-stream"
-         )
-     else:
-         response = gen.sync_completion(req)
+             return Response(
+                 chunk_processor(gen.chunked_completion(req)),
+                 mimetype="text/event-stream",
+             )
+         else:
+             response = gen.sync_completion(req)

-         messages.append(response.choices[0].message)
-         print(messages[-1].content)
-
-     return json.dumps(_del_none(asdict(response)))
+             return json.dumps(_del_none(asdict(response)))

+     @app.route("/models", methods=["GET"])
+     def models_endpoint():
+         return json.dumps(asdict(get_model_info_list(args)))
+
+     @app.route("/models/<model_id>", methods=["GET"])
+     def models_retrieve_endpoint(model_id):
+         if response := retrieve_model_info(args, model_id):
+             return json.dumps(asdict(response))
+         else:
+             return "Model not found", 404
+
+     return app


def initialize_generator(args) -> OpenAiApiGenerator:
@@ -103,6 +108,5 @@ def initialize_generator(args) -> OpenAiApiGenerator:


def main(args):
-     global gen
-     gen = initialize_generator(args)
+     app = create_app(args)
    app.run()
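
For reference, a minimal client sketch for the reworked /chat endpoint (not part of the diff): it assumes the app returned by create_app(args) is being served locally on Flask's default port 5000, that CompletionRequest accepts an OpenAI-style model/messages payload, and that the third-party requests library is available. Note that the handler compares data.get("stream") against the string "true", not a JSON boolean.

import json

import requests  # assumed installed; any HTTP client works

BASE = "http://localhost:5000"  # Flask's default host/port, an assumption
payload = {
    "model": "placeholder-model",  # placeholder id, not from the diff
    "messages": [{"role": "user", "content": "Hello!"}],
}

# Non-streaming: the endpoint returns one JSON-encoded CompletionResponse.
resp = requests.post(f"{BASE}/chat", json=payload)
print(json.loads(resp.text))

# Streaming: the handler keys off the string "true", not a boolean.
payload["stream"] = "true"
with requests.post(f"{BASE}/chat", json=payload, stream=True) as resp:
    # Each chunk is a JSON-encoded CompletionResponseChunk.
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="")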
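
The two new read-only routes can be exercised the same way; a sketch under the same assumptions, with the model id as a placeholder:

import requests

BASE = "http://localhost:5000"

# GET /models returns the JSON-encoded model info list.
print(requests.get(f"{BASE}/models").json())

# GET /models/<model_id> returns a single entry, or a 404 with "Model not found".
resp = requests.get(f"{BASE}/models/placeholder-model-id")
print(resp.json() if resp.ok else resp.text)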