Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Openai api models endpoint #1000

Merged
merged 1 commit into from
Aug 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 86 additions & 0 deletions api/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os

from dataclasses import dataclass
from pwd import getpwuid
from typing import List, Union

from download import is_model_downloaded, load_model_configs

"""Helper functions for the OpenAI API Models endpoint.

See https://platform.openai.com/docs/api-reference/models for the full specification and details.
Please create an issue if anything doesn't match the specification.
"""


@dataclass
class ModelInfo:
    """Description of a single model, shaped after the OpenAI API Model object.

    Reference: https://platform.openai.com/docs/api-reference/models/object
    """

    # Identifier of the model.
    id: str
    # Timestamp stored as an integer.
    created: int
    # Name of the model's owner.
    # NOTE(review): the OpenAI specification appears to call this field
    # `owned_by` -- confirm before relying on spec-compatible clients.
    owner: str
    # Object-type tag; defaults to "model".
    object: str = "model"


@dataclass
class ModelInfoList:
    """Container holding a collection of ModelInfo entries."""

    # The ModelInfo entries that make up the list.
    data: List[ModelInfo]
    # Object-type tag; defaults to "list".
    object: str = "list"


def retrieve_model_info(args, model_id: str) -> Union[ModelInfo, None]:
    """Implementation of the OpenAI API Retrieve Model endpoint.

    See https://platform.openai.com/docs/api-reference/models/retrieve

    Inputs:
        args: command line arguments; `args.model_directory` is the directory
            under which each model is looked up by its config name
        model_id: the id of the model requested

    Returns:
        ModelInfo describing the specified model if it has a known config and
        is downloaded, None otherwise.
    """
    model_config = load_model_configs().get(model_id)
    if not model_config or not is_model_downloaded(model_id, args.model_directory):
        return None

    path = args.model_directory / model_config.name
    # One stat() call supplies both the ctime and the owning uid.
    stat_result = os.stat(path)
    try:
        owner = getpwuid(stat_result.st_uid).pw_name
    except KeyError:
        # The uid has no passwd entry (e.g. files owned by an arbitrary uid in
        # a container); fall back to the numeric uid instead of failing.
        owner = str(stat_result.st_uid)

    return ModelInfo(
        id=model_config.name,
        created=int(stat_result.st_ctime),
        owner=owner,
    )


def get_model_info_list(args) -> ModelInfoList:
    """Implementation of the OpenAI API List Models endpoint.

    See https://platform.openai.com/docs/api-reference/models/list

    Inputs:
        args: command line arguments; `args.model_directory` is the directory
            under which each model is looked up by its config name

    Returns:
        ModelInfoList describing all downloaded models (its `data` is empty
        when no known model is downloaded).
    """
    data = []
    for model_id, model_config in load_model_configs().items():
        if not is_model_downloaded(model_id, args.model_directory):
            continue

        path = args.model_directory / model_config.name
        # One stat() call supplies both the ctime and the owning uid.
        stat_result = os.stat(path)
        try:
            owner = getpwuid(stat_result.st_uid).pw_name
        except KeyError:
            # The uid has no passwd entry (e.g. files owned by an arbitrary
            # uid in a container); fall back to the numeric uid so that one
            # such model does not break the whole listing.
            owner = str(stat_result.st_uid)

        data.append(
            ModelInfo(
                id=model_config.name,
                created=int(stat_result.st_ctime),
                owner=owner,
            )
        )
    return ModelInfoList(data=data)
110 changes: 57 additions & 53 deletions server.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,79 +9,84 @@
from dataclasses import asdict
from typing import Dict, List, Union

from api.api import AssistantMessage, CompletionRequest, OpenAiApiGenerator, UserMessage
from api.api import CompletionRequest, OpenAiApiGenerator
from api.models import get_model_info_list, retrieve_model_info

from build.builder import BuilderArgs, TokenizerArgs
from flask import Flask, request, Response
from generate import GeneratorArgs


"""
Creates a flask app that can be used to serve the model as a chat API.
"""
app = Flask(__name__)
# Messages and gen are kept global so they can be accessed by the flask app endpoints.
messages: list = []
gen: OpenAiApiGenerator = None
def create_app(args):
"""
Creates a flask app that can be used to serve the model as a chat API.
"""
app = Flask(__name__)

gen: OpenAiApiGenerator = initialize_generator(args)

def _del_none(d: Union[Dict, List]) -> Union[Dict, List]:
"""Recursively delete None values from a dictionary."""
if type(d) is dict:
return {k: _del_none(v) for k, v in d.items() if v}
elif type(d) is list:
return [_del_none(v) for v in d if v]
return d
def _del_none(d: Union[Dict, List]) -> Union[Dict, List]:
"""Recursively delete None values from a dictionary."""
if type(d) is dict:
return {k: _del_none(v) for k, v in d.items() if v}
elif type(d) is list:
return [_del_none(v) for v in d if v]
return d

@app.route("/chat", methods=["POST"])
def chat_endpoint():
"""
Endpoint for the Chat API. This endpoint is used to generate a response to a user prompt.
This endpoint emulates the behavior of the OpenAI Chat API. (https://platform.openai.com/docs/api-reference/chat)

@app.route("/chat", methods=["POST"])
def chat_endpoint():
"""
Endpoint for the Chat API. This endpoint is used to generate a response to a user prompt.
This endpoint emulates the behavior of the OpenAI Chat API. (https://platform.openai.com/docs/api-reference/chat)
** Warning ** : Not all arguments of the CompletionRequest are consumed.

** Warning ** : Not all arguments of the CompletionRequest are consumed.
See https://github.com/pytorch/torchchat/issues/973 and the OpenAiApiGenerator class for more details.

See https://github.com/pytorch/torchchat/issues/973 and the OpenAiApiGenerator class for more details.
If stream is set to true, the response will be streamed back as a series of CompletionResponseChunk objects. Otherwise,
a single CompletionResponse object will be returned.
"""

If stream is set to true, the response will be streamed back as a series of CompletionResponseChunk objects. Otherwise,
a single CompletionResponse object will be returned.
"""
print(" === Completion Request ===")

print(" === Completion Request ===")
# Parse the request in to a CompletionRequest object
data = request.get_json()
req = CompletionRequest(**data)

# Parse the request in to a CompletionRequest object
data = request.get_json()
req = CompletionRequest(**data)
if data.get("stream") == "true":

# Add the user message to our internal message history.
messages.append(UserMessage(**req.messages[-1]))
def chunk_processor(chunked_completion_generator):
"""Inline function for postprocessing CompletionResponseChunk objects.

if data.get("stream") == "true":
Here, we just jsonify the chunk and yield it as a string.
"""
for chunk in chunked_completion_generator:
if (next_tok := chunk.choices[0].delta.content) is None:
next_tok = ""
print(next_tok, end="")
yield json.dumps(_del_none(asdict(chunk)))

def chunk_processor(chunked_completion_generator):
"""Inline function for postprocessing CompletionResponseChunk objects.
return Response(
chunk_processor(gen.chunked_completion(req)),
mimetype="text/event-stream",
)
else:
response = gen.sync_completion(req)

Here, we just jsonify the chunk and yield it as a string.
"""
messages.append(AssistantMessage(content=""))
for chunk in chunked_completion_generator:
if (next_tok := chunk.choices[0].delta.content) is None:
next_tok = ""
messages[-1].content += next_tok
print(next_tok, end="")
yield json.dumps(_del_none(asdict(chunk)))
return json.dumps(_del_none(asdict(response)))

return Response(
chunk_processor(gen.chunked_completion(req)), mimetype="text/event-stream"
)
else:
response = gen.sync_completion(req)
@app.route("/models", methods=["GET"])
def models_endpoint():
    """Serve the list of model descriptions as a JSON string."""
    model_list = get_model_info_list(args)
    return json.dumps(asdict(model_list))

messages.append(response.choices[0].message)
print(messages[-1].content)
@app.route("/models/<model_id>", methods=["GET"])
def models_retrieve_endpoint(model_id):
    """Serve the description of one model as a JSON string, or a 404.

    Inputs:
        model_id: the model id taken from the request path

    Returns:
        The JSON-encoded model description when the lookup yields a result,
        otherwise the tuple ("Model not found", 404).
    """
    model_info = retrieve_model_info(args, model_id)
    if not model_info:
        return "Model not found", 404
    return json.dumps(asdict(model_info))

return json.dumps(_del_none(asdict(response)))
return app


def initialize_generator(args) -> OpenAiApiGenerator:
Expand All @@ -103,6 +108,5 @@ def initialize_generator(args) -> OpenAiApiGenerator:


def main(args):
global gen
gen = initialize_generator(args)
app = create_app(args)
app.run()
Loading