Skip to content

Commit

Permalink
[WIP] added function calling
Browse files Browse the repository at this point in the history
  • Loading branch information
c0sogi committed Jul 11, 2023
1 parent c331b54 commit 2a10717
Show file tree
Hide file tree
Showing 86 changed files with 3,708 additions and 2,592 deletions.
4 changes: 3 additions & 1 deletion app/auth/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ async def login(
) -> Response:
if len(username) < 3:
"""Form data validation"""
raise FormValidationError({"username": "Ensure username has at least 03 characters"})
raise FormValidationError(
{"username": "Ensure username has at least 03 characters"}
)

if username == config.mysql_user and password == config.mysql_password:
"""Save `username` in session"""
Expand Down
17 changes: 12 additions & 5 deletions app/common/app_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
from starlette_admin.views import DropDown, Link

from app.auth.admin import MyAuthProvider
from app.common.app_settings_llama_cpp import monitor_llama_cpp_server
from app.common.config import JWT_SECRET, Config
from app.database.connection import cache, db
from app.database.schemas.auth import ApiKeys, ApiWhiteLists, Users
from app.dependencies import USER_DEPENDENCY, api_service_dependency
Expand All @@ -25,6 +23,9 @@
from app.utils.logger import ApiLogger
from app.viewmodels.admin import ApiKeyAdminView, UserAdminView

from .app_settings_llama_cpp import monitor_llama_cpp_server
from .config import JWT_SECRET, Config


async def on_startup():
"""
Expand Down Expand Up @@ -163,9 +164,13 @@ def create_app(config: Config) -> FastAPI:
middlewares=[Middleware(SessionMiddleware, secret_key=JWT_SECRET)],
)
admin.add_view(UserAdminView(Users, icon="fa fa-users", label="Users"))
admin.add_view(ApiKeyAdminView(ApiKeys, icon="fa fa-key", label="API Keys"))
admin.add_view(
ModelView(ApiWhiteLists, icon="fa fa-list", label="API White Lists")
ApiKeyAdminView(ApiKeys, icon="fa fa-key", label="API Keys")
)
admin.add_view(
ModelView(
ApiWhiteLists, icon="fa fa-list", label="API White Lists"
)
)
admin.add_view(
DropDown(
Expand All @@ -187,7 +192,9 @@ def create_app(config: Config) -> FastAPI:
Trusted host middleware: Allowed host only
"""

new_app.add_middleware(dispatch=access_control, middleware_class=BaseHTTPMiddleware)
new_app.add_middleware(
dispatch=access_control, middleware_class=BaseHTTPMiddleware
)
new_app.add_middleware(
CORSMiddleware,
allow_origins=config.allowed_sites,
Expand Down
13 changes: 7 additions & 6 deletions app/common/app_settings_llama_cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,21 @@
import requests
from fastapi import FastAPI
from starlette.middleware.cors import CORSMiddleware
from app.common.config import Config

from app.shared import Shared
from app.utils.logger import ApiLogger

from .config import Config


def check_health(url: str, retry_count: int = 3) -> bool:
def check_health(url: str) -> bool:
"""Check if the given url is available or not"""
try:
schema = parse.urlparse(url).scheme
netloc = parse.urlparse(url).netloc
for _ in range(retry_count):
if requests.get(f"{schema}://{netloc}/health").status_code == 200:
return True
return False
if requests.get(f"{schema}://{netloc}/health").status_code != 200:
return False
return True
except Exception:
return False

Expand Down
10 changes: 8 additions & 2 deletions app/errors/api_exceptions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from dataclasses import dataclass
from typing import Optional

from fastapi.exceptions import HTTPException
from sqlalchemy.exc import OperationalError

from app.common.config import MAX_API_KEY, MAX_API_WHITELIST


Expand All @@ -26,7 +28,9 @@ def __init__(
ex: Optional[Exception] = None,
):
self.status_code = status_code
self.code = error_codes(status_code=status_code, internal_code=internal_code)
self.code = error_codes(
status_code=status_code, internal_code=internal_code
)
self.msg = msg
self.detail = detail
self.ex = ex
Expand All @@ -38,7 +42,9 @@ def __call__(
ex: Optional[Exception] = None,
) -> "APIException":
if (
self.msg is not None and self.detail is not None and lazy_format is not None
self.msg is not None
and self.detail is not None
and lazy_format is not None
): # lazy format for msg and detail
self.msg = self.msg.format(**lazy_format)
self.detail = self.detail.format(**lazy_format)
Expand Down
18 changes: 9 additions & 9 deletions app/errors/chat_exceptions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from typing import Any


class ChatException(Exception): # Base exception for chat
def __init__(self, *, msg: str | None = None) -> None:
self.msg = msg
Expand Down Expand Up @@ -64,13 +67,10 @@ def __init__(self, *, msg: str | None = None) -> None:
super().__init__(msg=msg)


class ChatStreamingInterruptedException(ChatInterruptedException):
def __init__(self, *, msg: str | None = None) -> None:
self.msg = msg
super().__init__(msg=msg)
class ChatFunctionCallException(ChatException):
"""Raised when function is called."""


class ChatGeneralInterruptedException(ChatInterruptedException):
def __init__(self, *, msg: str | None = None) -> None:
self.msg = msg
super().__init__(msg=msg)
def __init__(self, *, func_name: str, func_kwargs: dict[str, Any]) -> None:
self.func_name = func_name
self.func_kwargs = func_kwargs
super().__init__(msg=f"Function {func_name}({func_kwargs}) is called.")
27 changes: 16 additions & 11 deletions app/middlewares/token_validator.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
from time import time

from fastapi import HTTPException
from starlette.datastructures import QueryParams, Headers
from starlette.datastructures import Headers, QueryParams
from starlette.middleware.base import RequestResponseEndpoint
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from app.common.config import (
EXCEPT_PATH_LIST,
EXCEPT_PATH_REGEX,
)

from app.common.config import EXCEPT_PATH_LIST, EXCEPT_PATH_REGEX
from app.database.crud.api_keys import get_api_key_and_owner
from app.errors.api_exceptions import (
APIException,
InternalServerError,
Responses_400,
Responses_401,
InternalServerError,
exception_handler,
)
from app.models.base_models import UserToken
Expand Down Expand Up @@ -121,21 +120,27 @@ async def access_control(request: Request, call_next: RequestResponseEndpoint):
headers=request.headers,
cookies=request.cookies,
)
response: Response = await call_next(request) # actual endpoint response
response: Response = await call_next(
request
) # actual endpoint response

except Exception as exception: # If any error occurs...
error: HTTPException | InternalServerError | APIException = exception_handler(
error=exception
error: HTTPException | InternalServerError | APIException = (
exception_handler(error=exception)
)
response: Response = JSONResponse(
status_code=error.status_code,
content={
"status": error.status_code,
"msg": error.msg if not isinstance(error, HTTPException) else None,
"msg": error.msg
if not isinstance(error, HTTPException)
else None,
"detail": error.detail
if not isinstance(error, HTTPException)
else error.detail,
"code": error.code if not isinstance(error, HTTPException) else None,
"code": error.code
if not isinstance(error, HTTPException)
else None,
},
)
ApiLogger.clog(
Expand Down
26 changes: 20 additions & 6 deletions app/middlewares/trusted_hosts.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from typing import Sequence

from starlette.datastructures import URL, Headers
from starlette.responses import PlainTextResponse, RedirectResponse
from starlette.types import ASGIApp, Receive, Scope, Send

from app.errors.api_exceptions import Responses_500


Expand All @@ -14,18 +16,28 @@ def __init__(
www_redirect: bool = True,
):
self.app: ASGIApp = app
self.allowed_hosts: list = ["*"] if allowed_hosts is None else list(allowed_hosts)
self.allow_any: bool = "*" in allowed_hosts if allowed_hosts is not None else True
self.allowed_hosts: list = (
["*"] if allowed_hosts is None else list(allowed_hosts)
)
self.allow_any: bool = (
"*" in allowed_hosts if allowed_hosts is not None else True
)
self.www_redirect: bool = www_redirect
self.except_path: list = [] if except_path is None else list(except_path)
self.except_path: list = (
[] if except_path is None else list(except_path)
)
if allowed_hosts is not None:
for allowed_host in allowed_hosts:
if "*" in allowed_host[1:]:
raise Responses_500.middleware_exception
if (allowed_host.startswith("*") and allowed_host != "*") and not allowed_host.startswith("*."):
if (
allowed_host.startswith("*") and allowed_host != "*"
) and not allowed_host.startswith("*."):
raise Responses_500.middleware_exception

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
async def __call__(
self, scope: Scope, receive: Receive, send: Send
) -> None:
if self.allow_any or scope["type"] not in (
"http",
"websocket",
Expand Down Expand Up @@ -57,5 +69,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
redirect_url = url.replace(netloc="www." + url.netloc)
response = RedirectResponse(url=str(redirect_url))
else:
response = PlainTextResponse("Invalid host header", status_code=400)
response = PlainTextResponse(
"Invalid host header", status_code=400
)
await response(scope, receive, send)
22 changes: 0 additions & 22 deletions app/mixins/commands/browsing.py

This file was deleted.

15 changes: 0 additions & 15 deletions app/mixins/commands/summarize.py

This file was deleted.

65 changes: 0 additions & 65 deletions app/mixins/commands/vectorstore.py

This file was deleted.

Loading

0 comments on commit 2a10717

Please sign in to comment.