forked from xtekky/gpt4free
Merge pull request xtekky#1213 from thatlukinhasguy1/main
Make the API use FastAPI instead of Flask
Showing 5 changed files with 144 additions and 214 deletions.
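In substance, the commit swaps the Flask/WSGI stack (waitress, werkzeug, flask_cors, loguru logging) for a FastAPI app served by uvicorn, and moves the route handlers from methods registered inside run() to decorated async functions defined in __init__. A minimal sketch of the two registration styles, with placeholder handler bodies rather than the real handlers from the diff below:

    # Old style (removed): Flask handlers are plain callables, bound
    # imperatively by calling the decorator that app.route() returns.
    from flask import Flask, jsonify

    flask_app = Flask(__name__)

    def models():
        return jsonify({'object': 'list', 'data': []})

    flask_app.route('/v1/models', methods=['GET'])(models)

    # New style (added): FastAPI handlers are async functions declared
    # with route decorators; returned dicts are serialized to JSON.
    from fastapi import FastAPI

    fastapi_app = FastAPI()

    @fastapi_app.get('/v1/models')
    async def list_models():
        return {'object': 'list', 'data': []}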
@@ -1,227 +1,162 @@

Removed (the old Flask implementation):

import typing

import time
import json
import random
import string
import logging

from typing import Union
from loguru import logger
from waitress import serve
from ._logging import hook_logging
from ._tokenizer import tokenize
from .. import BaseProvider
from flask_cors import CORS
from werkzeug.serving import WSGIRequestHandler
from werkzeug.exceptions import default_exceptions
from werkzeug.middleware.proxy_fix import ProxyFix

from flask import (
    Flask,
    jsonify,
    make_response,
    request,
)

import g4f


class Api:
    __default_ip = '127.0.0.1'
    __default_port = 1337

    def __init__(self, engine: g4f, debug: bool = True, sentry: bool = False,
                 list_ignored_providers: typing.List[typing.Union[str, BaseProvider]] = None) -> None:
        self.engine = engine
        self.debug = debug
        self.sentry = sentry
        self.list_ignored_providers = list_ignored_providers
        self.log_level = logging.DEBUG if debug else logging.WARN

        hook_logging(level=self.log_level, format='[%(asctime)s] %(levelname)s in %(module)s: %(message)s')
        self.logger = logging.getLogger('waitress')

        self.app = Flask(__name__)
        self.app.wsgi_app = ProxyFix(self.app.wsgi_app, x_port=1)
        self.app.after_request(self.__after_request)

    def run(self, bind_str, threads=8):
        host, port = self.__parse_bind(bind_str)

        CORS(self.app, resources={r'/v1/*': {'supports_credentials': True, 'expose_headers': [
            'Content-Type',
            'Authorization',
            'X-Requested-With',
            'Accept',
            'Origin',
            'Access-Control-Request-Method',
            'Access-Control-Request-Headers',
            'Content-Disposition'], 'max_age': 600}})

        self.app.route('/v1/models', methods=['GET'])(self.models)
        self.app.route('/v1/models/<model_id>', methods=['GET'])(self.model_info)

        self.app.route('/v1/chat/completions', methods=['POST'])(self.chat_completions)
        self.app.route('/v1/completions', methods=['POST'])(self.completions)

        for ex in default_exceptions:
            self.app.register_error_handler(ex, self.__handle_error)

        if not self.debug:
            self.logger.warning(f'Serving on http://{host}:{port}')

        WSGIRequestHandler.protocol_version = 'HTTP/1.1'
        serve(self.app, host=host, port=port, ident=None, threads=threads)

    def __handle_error(self, e: Exception):
        self.logger.error(e)

        return make_response(jsonify({
            'code': e.code,
            'message': str(e.original_exception if self.debug and hasattr(e, 'original_exception') else e.name)}), 500)

    @staticmethod
    def __after_request(resp):
        resp.headers['X-Server'] = f'g4f/{g4f.version}'

        return resp

    def __parse_bind(self, bind_str):
        sections = bind_str.split(':', 2)
        if len(sections) < 2:
            try:
                port = int(sections[0])
                return self.__default_ip, port
            except ValueError:
                return sections[0], self.__default_port

        return sections[0], int(sections[1])

    async def home(self):
        return 'Hello world | https://127.0.0.1:1337/v1'

    async def chat_completions(self):
        model = request.json.get('model', 'gpt-3.5-turbo')
        stream = request.json.get('stream', False)
        messages = request.json.get('messages')

        logger.info(f'model: {model}, stream: {stream}, request: {messages[-1]["content"]}')

        config = None
        proxy = None

        try:
            config = json.load(open("config.json", "r", encoding="utf-8"))
            proxy = config["proxy"]

        except Exception:
            pass

        if proxy != None:
            response = self.engine.ChatCompletion.create(model=model,
                stream=stream, messages=messages,
                ignored=self.list_ignored_providers,
                proxy=proxy)
        else:
            response = self.engine.ChatCompletion.create(model=model,
                stream=stream, messages=messages,
                ignored=self.list_ignored_providers)

        completion_id = ''.join(random.choices(string.ascii_letters + string.digits, k=28))
        completion_timestamp = int(time.time())

        if not stream:
            prompt_tokens, _ = tokenize(''.join([message['content'] for message in messages]))
            completion_tokens, _ = tokenize(response)

            return {
                'id': f'chatcmpl-{completion_id}',
                'object': 'chat.completion',
                'created': completion_timestamp,
                'model': model,
                'choices': [
                    {
                        'index': 0,
                        'message': {
                            'role': 'assistant',
                            'content': response,
                        },
                        'finish_reason': 'stop',
                    }
                ],
                'usage': {
                    'prompt_tokens': prompt_tokens,
                    'completion_tokens': completion_tokens,
                    'total_tokens': prompt_tokens + completion_tokens,
                },
            }

        def streaming():
            try:
                for chunk in response:
                    completion_data = {
                        'id': f'chatcmpl-{completion_id}',
                        'object': 'chat.completion.chunk',
                        'created': completion_timestamp,
                        'model': model,
                        'choices': [
                            {
                                'index': 0,
                                'delta': {
                                    'content': chunk,
                                },
                                'finish_reason': None,
                            }
                        ],
                    }

                    content = json.dumps(completion_data, separators=(',', ':'))
                    yield f'data: {content}\n\n'
                    time.sleep(0.03)

                end_completion_data = {
                    'id': f'chatcmpl-{completion_id}',
                    'object': 'chat.completion.chunk',
                    'created': completion_timestamp,
                    'model': model,
                    'choices': [
                        {
                            'index': 0,
                            'delta': {},
                            'finish_reason': 'stop',
                        }
                    ],
                }

                content = json.dumps(end_completion_data, separators=(',', ':'))
                yield f'data: {content}\n\n'

                logger.success(f'model: {model}, stream: {stream}')

            except GeneratorExit:
                pass

        return self.app.response_class(streaming(), mimetype='text/event-stream')

    async def completions(self):
        return 'not working yet', 500

    async def model_info(self, model_name):
        model_info = (g4f.ModelUtils.convert[model_name])

        return jsonify({
            'id': model_name,
            'object': 'model',
            'created': 0,
            'owned_by': model_info.base_provider
        })

    async def models(self):
        model_list = [{
            'id': model,
            'object': 'model',
            'created': 0,
            'owned_by': 'g4f'} for model in g4f.Model.__all__()]

        return jsonify({
            'object': 'list',
            'data': model_list})

Added (the new FastAPI implementation):

import g4f; g4f.debug.logging = True
from fastapi import FastAPI, Response, Request
from typing import List, Union, Any, Dict, AnyStr
from ._tokenizer import tokenize
from .. import BaseProvider

import time
import json
import random
import string

import uvicorn
import nest_asyncio


class Api:
    def __init__(self, engine: g4f, debug: bool = True, sentry: bool = False,
                 list_ignored_providers: List[Union[str, BaseProvider]] = None) -> None:
        self.engine = engine
        self.debug = debug
        self.sentry = sentry
        self.list_ignored_providers = list_ignored_providers

        self.app = FastAPI()
        nest_asyncio.apply()

        JSONObject = Dict[AnyStr, Any]
        JSONArray = List[Any]
        JSONStructure = Union[JSONArray, JSONObject]

        @self.app.get("/")
        async def read_root():
            return Response(content=json.dumps({"info": "g4f API"}, indent=4), media_type="application/json")

        @self.app.get("/v1")
        async def read_root_v1():
            return Response(content=json.dumps({"info": "Go to /v1/chat/completions or /v1/models."}, indent=4), media_type="application/json")

        @self.app.get("/v1/models")
        async def models():
            model_list = [{
                'id': model,
                'object': 'model',
                'created': 0,
                'owned_by': 'g4f'} for model in g4f.Model.__all__()]

            return Response(content=json.dumps({
                'object': 'list',
                'data': model_list}, indent=4), media_type="application/json")

        @self.app.get("/v1/models/{model_name}")
        async def model_info(model_name: str):
            try:
                model_info = (g4f.ModelUtils.convert[model_name])

                return Response(content=json.dumps({
                    'id': model_name,
                    'object': 'model',
                    'created': 0,
                    'owned_by': model_info.base_provider
                }, indent=4), media_type="application/json")
            except:
                return Response(content=json.dumps({"error": "The model does not exist."}, indent=4), media_type="application/json")

        @self.app.post("/v1/chat/completions")
        async def chat_completions(request: Request, item: JSONStructure = None):
            item_data = {
                'model': 'gpt-3.5-turbo',
                'stream': False,
            }

            item_data.update(item or {})
            model = item_data.get('model')
            stream = item_data.get('stream')
            messages = item_data.get('messages')

            try:
                response = g4f.ChatCompletion.create(model=model, stream=stream, messages=messages)
            except:
                return Response(content=json.dumps({"error": "An error occurred while generating the response."}, indent=4), media_type="application/json")

            completion_id = ''.join(random.choices(string.ascii_letters + string.digits, k=28))
            completion_timestamp = int(time.time())

            if not stream:
                prompt_tokens, _ = tokenize(''.join([message['content'] for message in messages]))
                completion_tokens, _ = tokenize(response)

                json_data = {
                    'id': f'chatcmpl-{completion_id}',
                    'object': 'chat.completion',
                    'created': completion_timestamp,
                    'model': model,
                    'choices': [
                        {
                            'index': 0,
                            'message': {
                                'role': 'assistant',
                                'content': response,
                            },
                            'finish_reason': 'stop',
                        }
                    ],
                    'usage': {
                        'prompt_tokens': prompt_tokens,
                        'completion_tokens': completion_tokens,
                        'total_tokens': prompt_tokens + completion_tokens,
                    },
                }

                return Response(content=json.dumps(json_data, indent=4), media_type="application/json")

            def streaming():
                try:
                    for chunk in response:
                        completion_data = {
                            'id': f'chatcmpl-{completion_id}',
                            'object': 'chat.completion.chunk',
                            'created': completion_timestamp,
                            'model': model,
                            'choices': [
                                {
                                    'index': 0,
                                    'delta': {
                                        'content': chunk,
                                    },
                                    'finish_reason': None,
                                }
                            ],
                        }

                        content = json.dumps(completion_data, separators=(',', ':'))
                        yield f'data: {content}\n\n'
                        time.sleep(0.03)

                    end_completion_data = {
                        'id': f'chatcmpl-{completion_id}',
                        'object': 'chat.completion.chunk',
                        'created': completion_timestamp,
                        'model': model,
                        'choices': [
                            {
                                'index': 0,
                                'delta': {},
                                'finish_reason': 'stop',
                            }
                        ],
                    }

                    content = json.dumps(end_completion_data, separators=(',', ':'))
                    yield f'data: {content}\n\n'

                except GeneratorExit:
                    pass

            return Response(content=json.dumps(streaming(), indent=4), media_type="application/json")

        @self.app.post("/v1/completions")
        async def completions():
            return Response(content=json.dumps({'info': 'Not working yet.'}, indent=4), media_type="application/json")

    def run(self, ip):
        split_ip = ip.split(":")
        uvicorn.run(app=self.app, host=split_ip[0], port=int(split_ip[1]), use_colors=False)
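The new run() expects a single 'host:port' string and splits it for uvicorn, so starting and querying the server looks roughly like this. A sketch only: the from g4f.api import Api path and the client snippet are illustrative assumptions, not part of the commit:

    # Server side: construct the API around the g4f engine and serve it.
    import g4f
    from g4f.api import Api  # assumed import path for the module in this diff

    Api(engine=g4f).run('127.0.0.1:1337')

    # Client side (separate process): a non-streaming chat completion.
    import requests

    resp = requests.post('http://127.0.0.1:1337/v1/chat/completions', json={
        'model': 'gpt-3.5-turbo',
        'stream': False,
        'messages': [{'role': 'user', 'content': 'Hello'}],
    })
    print(resp.json()['choices'][0]['message']['content'])

One caveat when reading the streaming branch: json.dumps(streaming(), indent=4) cannot serialize a generator, so stream=True requests will fail when the response is built; the conventional FastAPI construct for the data: ...\n\n chunks that streaming() yields would be fastapi.responses.StreamingResponse(streaming(), media_type='text/event-stream').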