Skip to content

Commit

Permalink
Function server
Browse files Browse the repository at this point in the history
  • Loading branch information
pseudotensor committed May 23, 2024
1 parent e11d4c7 commit 8a76cbf
Show file tree
Hide file tree
Showing 8 changed files with 248 additions and 64 deletions.
1 change: 1 addition & 0 deletions openai_server/function_server.py
2 changes: 1 addition & 1 deletion openai_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from typing import List, Dict, Optional, Literal, Union
from pydantic import BaseModel, Field

from fastapi import FastAPI, Header, HTTPException
from fastapi import FastAPI, Header, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi import Request, Depends
from fastapi.responses import JSONResponse, Response, StreamingResponse
Expand Down
16 changes: 11 additions & 5 deletions openai_server/server_start.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,14 @@ def run_server(host: str = '0.0.0.0',
# https://docs.gunicorn.org/en/stable/design.html#how-many-workers
workers: int = 1,
app: Union[str, FastAPI] = None,
is_openai_server: bool = True,
):
if workers == 0:
workers = min(16, os.cpu_count() * 2 + 1)
assert app is not None

name = 'OpenAI' if is_openai_server else 'Function'

os.environ['GRADIO_PREFIX'] = gradio_prefix or 'http'
os.environ['GRADIO_SERVER_HOST'] = gradio_host or 'localhost'
os.environ['GRADIO_SERVER_PORT'] = gradio_port or '7860'
Expand All @@ -54,8 +57,8 @@ def run_server(host: str = '0.0.0.0',

prefix = 'https' if ssl_keyfile and ssl_certfile else 'http'
from openai_server.log import logger
logger.info(f'OpenAI API URL: {prefix}://{host}:{port}')
logger.info(f'OpenAI API key: {server_api_key}')
logger.info(f'{name} API URL: {prefix}://{host}:{port}')
logger.info(f'{name} API key: {server_api_key}')

logging.getLogger("uvicorn.error").propagate = False

Expand All @@ -67,9 +70,12 @@ def run_server(host: str = '0.0.0.0',


def run(wait=True, **kwargs):
assert 'is_openai_server' in kwargs
name = 'OpenAI' if kwargs['is_openai_server'] else 'Function'
print(kwargs)

if kwargs['workers'] > 1 or kwargs['workers'] == 0:
print("Multi-worker OpenAI Proxy uvicorn: %s" % kwargs['workers'])
print(f"Multi-worker {name} Proxy uvicorn: {kwargs['workers']}")
# avoid CUDA forking
command = ['python', 'openai_server/server_start.py']
# Convert the kwargs to command line arguments
Expand All @@ -81,10 +87,10 @@ def run(wait=True, **kwargs):
for c in iter(lambda: process.stdout.read(1), b''):
sys.stdout.write(c.decode('utf-8', errors='replace')) # Ensure decoding from bytes to str
elif wait:
print("Single-worker OpenAI Proxy uvicorn in this thread: %s" % kwargs['workers'])
print(f"Single-worker {name} Proxy uvicorn in this thread: {kwargs['workers']}")
run_server(**kwargs)
else:
print("Single-worker OpenAI Proxy uvicorn in new thread: %s" % kwargs['workers'])
print(f"Single-worker {name} Proxy uvicorn in new thread: {kwargs['workers']}")
Thread(target=run_server, kwargs=kwargs, daemon=True).start()


Expand Down
37 changes: 37 additions & 0 deletions src/function_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import requests
import json


def execute_function_on_server(host: str, port: int, function_name: str, args: list, kwargs: dict, use_disk: bool,
                               timeout: float = 60.0):
    """POST a function-execution request to the function server.

    :param host: server hostname or IP
    :param port: server port
    :param function_name: name of the registered function to execute
    :param args: positional arguments forwarded to the function
    :param kwargs: keyword arguments forwarded to the function
    :param use_disk: if True, the server writes the result to shared disk and returns its path
    :param timeout: request timeout in seconds (requests would otherwise wait forever)
    :returns: parsed JSON response dict on success, else {"error": <message>}
    """
    url = f"http://{host}:{port}/execute_function/"
    payload = {
        "function_name": function_name,
        "args": args,
        "kwargs": kwargs,
        "use_disk": use_disk,
    }
    try:
        response = requests.post(url, json=payload, timeout=timeout)
    except requests.RequestException as e:
        # Connection refused, DNS failure, timeout, etc. — report instead of raising.
        return {"error": str(e)}
    if response.status_code == 200:
        return response.json()
    # Error bodies are not always JSON: the server's global exception handler
    # returns PlainTextResponse, which would make response.json() raise.
    try:
        detail = response.json().get("detail", response.text)
    except ValueError:
        detail = response.text
    return {"error": detail}


def read_result_from_disk(file_path: str):
    """Load and return the JSON-serialized function result stored at *file_path*."""
    with open(file_path, "r") as result_file:
        return json.load(result_file)


def function_client(host, port, function_name, args, kwargs, use_disk):
    """Execute *function_name* on the remote function server, print progress, and return the result.

    :param host: server hostname or IP
    :param port: server port
    :param function_name: name of the registered function to execute
    :param args: positional arguments forwarded to the function
    :param kwargs: keyword arguments forwarded to the function
    :param use_disk: if True, the result is read back from the shared-disk file the server wrote
    :returns: the function result, or None if the server reported an error
        (previously the result was printed but discarded)
    """
    execute_result = execute_function_on_server(host, port, function_name, args, kwargs, use_disk)
    if "error" in execute_result:
        print(f"Error: {execute_result['error']}")
        return None
    if use_disk:
        file_path = execute_result["file_path"]
        print(f"Result saved at: {file_path}")
        result_from_disk = read_result_from_disk(file_path)
        print("Result read from disk:", result_from_disk)
        return result_from_disk
    print("Result received directly:", execute_result["result"])
    return execute_result["result"]
111 changes: 111 additions & 0 deletions src/function_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
import inspect
import json
import os
import sys
import tempfile
import typing
from traceback import print_exception
from typing import Union

from pydantic import BaseModel

from fastapi import FastAPI, Header, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi import Request, Depends
from fastapi.responses import JSONResponse, Response, StreamingResponse
from sse_starlette import EventSourceResponse
from starlette.responses import PlainTextResponse

sys.path.append('src')


# similar to openai_server/server.py
# similar to openai_server/server.py
def verify_api_key(authorization: str = Header(None)) -> None:
    """Reject the request with 401 unless it carries the configured bearer API key."""
    expected_key = os.getenv('H2OGPT_OPENAI_API_KEY', 'EMPTY')
    if expected_key == 'EMPTY':
        # 'EMPTY' disables auth entirely ('' cannot be round-tripped through env)
        return
    if expected_key and authorization != f"Bearer {expected_key}":
        # covers both a missing header (None) and a wrong key
        raise HTTPException(status_code=401, detail="Unauthorized")


app = FastAPI()
# Dependency list applied to routes that must present the bearer API key.
check_key = [Depends(verify_api_key)]
# Fully permissive CORS (any origin/method/header) — intended for an internal service.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"]
)


class InvalidRequestError(Exception):
    """Wrapper for any uncaught server-side exception; its text is returned to the client as a 400 plain-text response."""
    pass


class FunctionRequest(BaseModel):
    """Request payload for POST /execute_function/."""
    function_name: str  # key into the server-side FUNCTIONS registry
    args: list  # positional arguments forwarded to the function
    kwargs: dict  # keyword arguments forwarded to the function
    use_disk: bool = False  # if True, write the JSON result to shared disk instead of returning it inline


# Example functions
# Example functions
def example_function1(x, y):
    """Demo function for the executor: return ``x + y`` (sum, concatenation, etc.)."""
    combined = x + y
    return combined


def example_function2(path: str):
    """List Word documents (``.doc``/``.docx``) directly inside *path* (non-recursive).

    :param path: directory to scan
    :returns: {"documents": [matching filenames]}
    :raises ValueError: if *path* does not exist or is not a directory
    """
    if not os.path.exists(path):
        raise ValueError("Path does not exist")
    if not os.path.isdir(path):
        raise ValueError("Path is not a directory")
    # str.endswith accepts a tuple, so one call covers both extensions
    docs = [f for f in os.listdir(path) if f.endswith(('.doc', '.docx'))]
    return {"documents": docs}


@app.get("/health")
async def health() -> Response:
"""Health check."""
return Response(status_code=200)


@app.exception_handler(Exception)
async def validation_exception_handler(request, exc):
    """Catch-all handler: log the traceback server-side, return the error text
    to the client as a 400 plain-text response wrapped in InvalidRequestError.

    NOTE(review): despite the name, this is registered for every uncaught
    Exception, not only validation errors.
    """
    print_exception(exc)
    exc2 = InvalidRequestError(str(exc))
    return PlainTextResponse(str(exc2), status_code=400)


@app.options("/", dependencies=check_key)
async def options_route():
return JSONResponse(content="OK")


@app.post("/execute_function/")
def execute_function(request: FunctionRequest):
# Mapping of function names to function objects
FUNCTIONS = {
'example_function1': example_function1,
'example_function2': example_function2,
}
try:
# Fetch the function from the function map
func = FUNCTIONS.get(request.function_name)
if not func:
raise ValueError("Function not found")

# Call the function with args and kwargs
result = func(*request.args, **request.kwargs)

if request.use_disk:
# Save the result to a file on the shared disk
file_path = "/path/to/shared/disk/function_result.json"
with open(file_path, "w") as f:
json.dump(result, f)
return {"status": "success", "file_path": file_path}
else:
# Return the result directly
return {"status": "success", "result": result}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
15 changes: 13 additions & 2 deletions src/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,10 +262,15 @@ def main(
cli: bool = False,
cli_loop: bool = True,
gradio: bool = True,

openai_server: bool = True,
openai_port: int = 5001 if sys.platform == "darwin" else 5000,
openai_workers: int = 1,

function_server: bool = False,
function_server_port: int = 5001 if sys.platform == "darwin" else 5000,
function_server_workers: int = 1,

gradio_offline_level: int = 0,
server_name: str = "0.0.0.0",
share: bool = False,
Expand Down Expand Up @@ -796,10 +801,16 @@ def main(
:param cli: whether to use CLI (non-gradio) interface.
:param cli_loop: whether to loop for CLI (False usually only for testing)
:param gradio: whether to enable gradio, or to enable benchmark mode
:param openai_server: whether to launch OpenAI proxy server for local gradio server
Disabled if API is disabled or --auth=closed
Disabled if API is disabled
:param openai_port: port for OpenAI proxy server
:param openai_workers: number of workers for OpenAI (1 means 1 worker, 0 means all physical cores, else choose)
:param function_server: whether to launch the Function server, offloading document loading to a separate thread or forked workers
:param function_server_port: port for Function server
:param function_server_workers: number of workers for Function Server (1 means 1 worker, 0 means all physical cores, else choose)
:param gradio_offline_level: > 0, then change fonts so full offline
== 1 means backend won't need internet for fonts, but front-end UI might if font not cached
== 2 means backend and frontend don't need internet to download any fonts.
Expand Down Expand Up @@ -1702,7 +1713,7 @@ def main(
allow_api = bool(int(os.getenv('ALLOW_API', str(int(allow_api)))))

if openai_server and not allow_api:
print("Cannot enable OpenAI server when allow_api=False or auth is closed")
print("Cannot enable OpenAI server when allow_api=False")
openai_server = False

if not os.getenv('CLEAR_CLEAR_TORCH'):
Expand Down
Loading

0 comments on commit 8a76cbf

Please sign in to comment.