Commit a83c682

Add models endpoint
1 parent 6401f55 commit a83c682

File tree: api/api.py, api/models.py, server.py

3 files changed: +164 -71 lines changed


api/api.py (+21 -18)
@@ -86,10 +86,12 @@ class StreamOptions:
 
     include_usage: bool = False
 
+
 @dataclass
 class ResponseFormat:
     type: Optional[str] = None
 
+
 @dataclass
 class CompletionRequest:
     """A full chat completion request.
@@ -99,25 +101,26 @@ class CompletionRequest:
 
     messages: List[_AbstractMessage]
     model: str
-    frequency_penalty: float = 0.0 # unimplemented
-    logit_bias: Optional[Dict[str, float]] = None # unimplemented
-    logprobs: Optional[bool] = None # unimplemented
-    top_logprobs: Optional[int] = None # unimplemented
-    max_tokens: Optional[int] = None # unimplemented
+    frequency_penalty: float = 0.0  # unimplemented
+    logit_bias: Optional[Dict[str, float]] = None  # unimplemented
+    logprobs: Optional[bool] = None  # unimplemented
+    top_logprobs: Optional[int] = None  # unimplemented
+    max_tokens: Optional[int] = None  # unimplemented
     n: int = 1
-    presence_penalty: float = 0 # unimplemented
-    response_format: Optional[ResponseFormat] = None # unimplemented
-    seed: Optional[int] = None # unimplemented
-    service_tier: Optional[str] = None # unimplemented
-    stop: Optional[List[str]] = None # unimplemented
+    presence_penalty: float = 0  # unimplemented
+    response_format: Optional[ResponseFormat] = None  # unimplemented
+    seed: Optional[int] = None  # unimplemented
+    service_tier: Optional[str] = None  # unimplemented
+    stop: Optional[List[str]] = None  # unimplemented
     stream: bool = False
-    stream_options: Optional[StreamOptions] = None # unimplemented
-    temperature: Optional[float] = 1.0 # unimplemented
-    top_p: Optional[float] = 1.0 # unimplemented
-    tools: Optional[List[Any]] = None # unimplemented
-    tool_choice: Optional[Union[str, Any]] = None # unimplemented
-    parallel_tool_calls: Optional[bool] = None # unimplemented
-    user: Optional[str] = None # unimplemented
+    stream_options: Optional[StreamOptions] = None  # unimplemented
+    temperature: Optional[float] = 1.0  # unimplemented
+    top_p: Optional[float] = 1.0  # unimplemented
+    tools: Optional[List[Any]] = None  # unimplemented
+    tool_choice: Optional[Union[str, Any]] = None  # unimplemented
+    parallel_tool_calls: Optional[bool] = None  # unimplemented
+    user: Optional[str] = None  # unimplemented
+
 
 @dataclass
 class CompletionChoice:
@@ -155,7 +158,7 @@ class CompletionResponse:
     choices: List[CompletionChoice]
    created: int
    model: str
-    system_fingerprint: str
+    system_fingerprint: str
     service_tier: Optional[str] = None
     usage: Optional[UsageStats] = None
     object: str = "chat.completion"
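
For reference, a minimal /chat request that exercises only the implemented fields of CompletionRequest might look like the sketch below. This is an illustration, not part of the commit; "llama3" is a placeholder model id, and the server (see server.py below) compares the stream field against the string "true" rather than a boolean.

    # Hypothetical payload for POST /chat, mirroring the dataclass fields above.
    payload = {
        "model": "llama3",  # placeholder; use any downloaded model's id
        "messages": [{"role": "user", "content": "Hello!"}],
        "stream": "true",  # the handler checks for the string "true"
    }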

api/models.py (+86, new file)
@@ -0,0 +1,86 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+
+from dataclasses import dataclass
+from pwd import getpwuid
+from typing import List, Union
+
+from download import is_model_downloaded, load_model_configs
+
+"""Helper functions for the OpenAI API Models endpoint.
+
+See https://platform.openai.com/docs/api-reference/models for the full specification and details.
+Please create an issue if anything doesn't match the specification.
+"""
+
+
+@dataclass
+class ModelInfo:
+    """The Model object per the OpenAI API specification containing information about a model.
+
+    See https://platform.openai.com/docs/api-reference/models/object for more details.
+    """
+
+    id: str
+    created: int
+    owner: str
+    object: str = "model"
+
+
+@dataclass
+class ModelInfoList:
+    """A list of ModelInfo objects."""
+
+    data: List[ModelInfo]
+    object: str = "list"
+
+
+def retrieve_model_info(args, model_id: str) -> Union[ModelInfo, None]:
+    """Implementation of the OpenAI API Retrieve Model endpoint.
+
+    See https://platform.openai.com/docs/api-reference/models/retrieve
+
+    Inputs:
+        args: command line arguments
+        model_id: the id of the model requested
+
+    Returns:
+        ModelInfo describing the specified model if it is downloaded, None otherwise.
+    """
+    if model_config := load_model_configs().get(model_id):
+        if is_model_downloaded(model_id, args.model_directory):
+            path = args.model_directory / model_config.name
+            created = int(os.path.getctime(path))
+            owner = getpwuid(os.stat(path).st_uid).pw_name
+
+            return ModelInfo(id=model_config.name, created=created, owner=owner)
+        return None
+    return None
+
+
+def get_model_info_list(args) -> ModelInfoList:
+    """Implementation of the OpenAI API List Models endpoint.
+
+    See https://platform.openai.com/docs/api-reference/models/list
+
+    Inputs:
+        args: command line arguments
+
+    Returns:
+        ModelInfoList describing all downloaded models.
+    """
+    data = []
+    for model_id, model_config in load_model_configs().items():
+        if is_model_downloaded(model_id, args.model_directory):
+            path = args.model_directory / model_config.name
+            created = int(os.path.getctime(path))
+            owner = getpwuid(os.stat(path).st_uid).pw_name
+
+            data.append(ModelInfo(id=model_config.name, created=created, owner=owner))
+    response = ModelInfoList(data=data)
+    return response
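
For illustration: once the server is running, these helpers back two read-only HTTP endpoints. A minimal sketch using the requests library; the URL assumes Flask's default port 5000 from app.run(), and "llama3" is a placeholder model id.

    import json

    import requests

    # List every downloaded model as an OpenAI-style "list" object.
    resp = requests.get("http://127.0.0.1:5000/models")
    print(json.loads(resp.text))

    # Retrieve one model by id; the server answers 404 if it is not downloaded.
    resp = requests.get("http://127.0.0.1:5000/models/llama3")
    print(resp.status_code, resp.text)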

server.py (+57 -53)
@@ -9,79 +9,84 @@
 from dataclasses import asdict
 from typing import Dict, List, Union
 
-from api.api import AssistantMessage, CompletionRequest, OpenAiApiGenerator, UserMessage
+from api.api import CompletionRequest, OpenAiApiGenerator
+from api.models import get_model_info_list, retrieve_model_info
 
 from build.builder import BuilderArgs, TokenizerArgs
 from flask import Flask, request, Response
 from generate import GeneratorArgs
 
 
-"""
-Creates a flask app that can be used to serve the model as a chat API.
-"""
-app = Flask(__name__)
-# Messages and gen are kept global so they can be accessed by the flask app endpoints.
-messages: list = []
-gen: OpenAiApiGenerator = None
+def create_app(args):
+    """
+    Creates a flask app that can be used to serve the model as a chat API.
+    """
+    app = Flask(__name__)
 
+    gen: OpenAiApiGenerator = initialize_generator(args)
 
-def _del_none(d: Union[Dict, List]) -> Union[Dict, List]:
-    """Recursively delete None values from a dictionary."""
-    if type(d) is dict:
-        return {k: _del_none(v) for k, v in d.items() if v}
-    elif type(d) is list:
-        return [_del_none(v) for v in d if v]
-    return d
+    def _del_none(d: Union[Dict, List]) -> Union[Dict, List]:
+        """Recursively delete None values from a dictionary."""
+        if type(d) is dict:
+            return {k: _del_none(v) for k, v in d.items() if v}
+        elif type(d) is list:
+            return [_del_none(v) for v in d if v]
+        return d
 
+    @app.route("/chat", methods=["POST"])
+    def chat_endpoint():
+        """
+        Endpoint for the Chat API. This endpoint is used to generate a response to a user prompt.
+        This endpoint emulates the behavior of the OpenAI Chat API. (https://platform.openai.com/docs/api-reference/chat)
 
-@app.route("/chat", methods=["POST"])
-def chat_endpoint():
-    """
-    Endpoint for the Chat API. This endpoint is used to generate a response to a user prompt.
-    This endpoint emulates the behavior of the OpenAI Chat API. (https://platform.openai.com/docs/api-reference/chat)
+        ** Warning ** : Not all arguments of the CompletionRequest are consumed.
 
-    ** Warning ** : Not all arguments of the CompletionRequest are consumed.
+        See https://github.com/pytorch/torchchat/issues/973 and the OpenAiApiGenerator class for more details.
 
-    See https://github.com/pytorch/torchchat/issues/973 and the OpenAiApiGenerator class for more details.
+        If stream is set to true, the response will be streamed back as a series of CompletionResponseChunk objects. Otherwise,
+        a single CompletionResponse object will be returned.
+        """
 
-    If stream is set to true, the response will be streamed back as a series of CompletionResponseChunk objects. Otherwise,
-    a single CompletionResponse object will be returned.
-    """
+        print(" === Completion Request ===")
 
-    print(" === Completion Request ===")
+        # Parse the request into a CompletionRequest object
+        data = request.get_json()
+        req = CompletionRequest(**data)
 
-    # Parse the request in to a CompletionRequest object
-    data = request.get_json()
-    req = CompletionRequest(**data)
+        if data.get("stream") == "true":
 
-    # Add the user message to our internal message history.
-    messages.append(UserMessage(**req.messages[-1]))
+            def chunk_processor(chunked_completion_generator):
+                """Inline function for postprocessing CompletionResponseChunk objects.
 
-    if data.get("stream") == "true":
+                Here, we just jsonify the chunk and yield it as a string.
+                """
+                for chunk in chunked_completion_generator:
+                    if (next_tok := chunk.choices[0].delta.content) is None:
+                        next_tok = ""
+                    print(next_tok, end="")
+                    yield json.dumps(_del_none(asdict(chunk)))
 
-        def chunk_processor(chunked_completion_generator):
-            """Inline function for postprocessing CompletionResponseChunk objects.
+            return Response(
+                chunk_processor(gen.chunked_completion(req)),
+                mimetype="text/event-stream",
+            )
+        else:
+            response = gen.sync_completion(req)
 
-            Here, we just jsonify the chunk and yield it as a string.
-            """
-            messages.append(AssistantMessage(content=""))
-            for chunk in chunked_completion_generator:
-                if (next_tok := chunk.choices[0].delta.content) is None:
-                    next_tok = ""
-                messages[-1].content += next_tok
-                print(next_tok, end="")
-                yield json.dumps(_del_none(asdict(chunk)))
+            return json.dumps(_del_none(asdict(response)))
 
-        return Response(
-            chunk_processor(gen.chunked_completion(req)), mimetype="text/event-stream"
-        )
-    else:
-        response = gen.sync_completion(req)
+    @app.route("/models", methods=["GET"])
+    def models_endpoint():
+        return json.dumps(asdict(get_model_info_list(args)))
 
-        messages.append(response.choices[0].message)
-        print(messages[-1].content)
+    @app.route("/models/<model_id>", methods=["GET"])
+    def models_retrieve_endpoint(model_id):
+        if response := retrieve_model_info(args, model_id):
+            return json.dumps(asdict(response))
+        else:
+            return "Model not found", 404
 
-    return json.dumps(_del_none(asdict(response)))
+    return app
 
 
 def initialize_generator(args) -> OpenAiApiGenerator:
@@ -103,6 +108,5 @@ def initialize_generator(args) -> OpenAiApiGenerator:
 
 
 def main(args):
-    global gen
-    gen = initialize_generator(args)
+    app = create_app(args)
     app.run()
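
Since create_app(args) now returns the configured app, the endpoints can also be exercised in-process with Flask's test client instead of binding a port. A sketch, assuming args is the same parsed-arguments namespace that main(args) receives:

    import json

    def smoke_test(args):
        # Build the app without starting a server.
        app = create_app(args)
        client = app.test_client()

        # /models returns an OpenAI-style list object of downloaded models.
        models = json.loads(client.get("/models").get_data(as_text=True))
        assert models["object"] == "list"

        # An unknown model id should yield a 404 from /models/<model_id>.
        assert client.get("/models/not-a-real-model").status_code == 404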
