Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
c0sogi committed Jun 26, 2023
1 parent e1f2da1 commit ce768d4
Show file tree
Hide file tree
Showing 36 changed files with 907 additions and 591 deletions.
262 changes: 169 additions & 93 deletions app/common/app_settings.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
from concurrent.futures import ProcessPoolExecutor
from concurrent.futures.process import BrokenProcessPool
from multiprocessing import Process
from threading import Event
from threading import Thread
from time import sleep
from os import kill
from signal import SIGINT
from threading import Event, Thread
from urllib import parse

import requests
Expand All @@ -28,11 +26,12 @@
from app.shared import Shared
from app.utils.chat.managers.cache import CacheManager
from app.utils.js_initializer import js_url_initializer
from app.utils.logger import api_logger
from app.utils.logger import ApiLogger
from app.viewmodels.admin import ApiKeyAdminView, UserAdminView


def check_health(url: str) -> bool:
"""Check if the given url is available or not"""
try:
schema = parse.urlparse(url).scheme
netloc = parse.urlparse(url).netloc
Expand All @@ -43,48 +42,183 @@ def check_health(url: str) -> bool:
return False


def start_llama_cpp_server():
def start_llama_cpp_server(shared: Shared):
"""Start Llama CPP server. if it is already running, terminate it first."""
from app.start_llama_cpp_server import run

if Shared().process is not None and Shared().process.is_alive():
api_logger.warning("Terminating existing Llama CPP server")
Shared().process.terminate()
Shared().process.join()
if shared.process.is_alive():
ApiLogger.cwarning("Terminating existing Llama CPP server")
shared.process.terminate()
shared.process.join()

api_logger.critical("Starting Llama CPP server")
Shared().process = Process(target=run, args=(Shared().process_terminate_signal,))
Shared().process.start()
ApiLogger.ccritical("Starting Llama CPP server")
shared.process = Process(target=run, daemon=True)
shared.process.start()


def shutdown_llama_cpp_server():
api_logger.critical("Shutting down Llama CPP server")
Shared().process_terminate_signal.set()
Shared().process.join()
def shutdown_llama_cpp_server(shared: Shared):
"""Shutdown Llama CPP server."""
ApiLogger.ccritical("Shutting down Llama CPP server")
if shared.process.is_alive() and shared.process.pid:
kill(shared.process.pid, SIGINT)
shared.process.join()


def monitor_llama_cpp_server(config: Config, terminate_signal: Event) -> None:
while not terminate_signal.is_set():
sleep(0.5)
if config.llama_cpp_completion_url:
if not check_health(config.llama_cpp_completion_url):
if config.is_llama_cpp_booting or terminate_signal.is_set():
continue
api_logger.error("Llama CPP server is not available")
config.is_llama_cpp_available = False
config.is_llama_cpp_booting = True
start_llama_cpp_server()
else:
config.is_llama_cpp_booting = False
config.is_llama_cpp_available = True
shutdown_llama_cpp_server()
def monitor_llama_cpp_server(
config: Config,
shared: Shared,
) -> None:
"""Monitors the Llama CPP server and handles server availability.
Parameters:
- `config: Config`: An object representing the server configuration.
- `shared: Shared`: An object representing shared data."""
thread_sigterm: Event = shared.thread_terminate_signal
if not config.llama_cpp_completion_url:
return
while True:
if not check_health(config.llama_cpp_completion_url):
if thread_sigterm.is_set():
break
if config.is_llama_cpp_booting:
continue
ApiLogger.cerror("Llama CPP server is not available")
config.is_llama_cpp_available = False
config.is_llama_cpp_booting = True
try:
start_llama_cpp_server(shared)
except ImportError:
ApiLogger.cerror("ImportError: Llama CPP server is not available")
return
else:
config.is_llama_cpp_booting = False
config.is_llama_cpp_available = True
shutdown_llama_cpp_server(shared)


async def on_startup():
"""
Performs necessary operations during application startup.
This function is called when the application is starting up.
It performs the following operations:
- Logs a startup message using ApiLogger.
- Retrieves the configuration object.
- Checks if the MySQL database connection is initiated and logs the status.
- Raises a ConnectionError if the Redis cache connection is not established.
- Checks if the Redis cache connection is initiated and logs the status.
- Attempts to import and set uvloop as the event loop policy, if available, and logs the result.
- Starts Llama CPP server monitoring if the Llama CPP completion URL is provided.
"""
ApiLogger.ccritical("⚙️ Booting up...")
config = Config.get()
shared = Shared()
if db.is_initiated:
ApiLogger.ccritical("MySQL DB connected!")
else:
ApiLogger.ccritical("MySQL DB connection failed!")
if cache.redis is None:
raise ConnectionError("Redis is not connected yet!")
if cache.is_initiated and await cache.redis.ping():
await CacheManager.delete_user(f"testaccount@{config.host_main}")
ApiLogger.ccritical("Redis CACHE connected!")
else:
ApiLogger.ccritical("Redis CACHE connection failed!")
try:
import asyncio

import uvloop # type: ignore

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
ApiLogger.ccritical("uvloop installed!")
except ImportError:
ApiLogger.ccritical("uvloop not installed!")

if config.llama_cpp_completion_url:
# Start Llama CPP server monitoring
ApiLogger.ccritical("Llama CPP server monitoring started!")
shared.thread = Thread(
target=monitor_llama_cpp_server,
args=(config, shared),
daemon=True,
)
shared.thread.start()


async def on_shutdown():
"""
Performs necessary operations during application shutdown.
This function is called when the application is shutting down.
It performs the following operations:
- Logs a shutdown message using ApiLogger.
- Sets terminate signals for shared threads and processes.
- Shuts down the process manager, if available.
- Shuts down the process pool executor, if available.
- Terminates and joins the process, if available.
- Joins the thread, if available.
- Closes the database and cache connections.
- Logs a message indicating the closure of DB and CACHE connections.
"""
ApiLogger.ccritical("⚙️ Shutting down...")
shared = Shared()
# await CacheManager.delete_user(f"testaccount@{HOST_MAIN}")
shared.thread_terminate_signal.set()
shared.process_terminate_signal.set()

process_manager = shared._process_manager
if process_manager is not None:
process_manager.shutdown()

process_pool_executor = shared._process_pool_executor
if process_pool_executor is not None:
process_pool_executor.shutdown(wait=True)

process = shared._process
if process is not None:
process.terminate()
process.join()

thread = shared._thread
if thread is not None:
thread.join()

await db.close()
await cache.close()
ApiLogger.ccritical("DB & CACHE connection closed!")


def create_app(config: Config) -> FastAPI:
"""
Creates and configures the FastAPI application.
Args:
config (Config): The configuration object.
Returns:
FastAPI: The configured FastAPI application.
This function creates a new FastAPI application, sets the specified title, description, and version,
and adds `on_startup` and `on_shutdown` event handlers.
It then starts the database and cache connections and initializes the JavaScript URL.
If the database engine is available, it adds admin views for managing users, API keys, and API white lists.
Next, it adds the necessary middlewares for access control, CORS, and trusted hosts.
It mounts the "/chat" endpoint for serving static files, and includes routers for index, websocket,
authentication, services, users, and user services.
Finally, it sets the application's config and shared state and returns the configured application.
"""
# Initialize app & db & js
new_app = FastAPI(
title=config.app_title,
description=config.app_description,
version=config.app_version,
on_startup=[on_startup],
on_shutdown=[on_shutdown],
)
db.start(config=config)
cache.start(config=config)
Expand Down Expand Up @@ -164,64 +298,6 @@ def create_app(config: Config) -> FastAPI:
tags=["User Services"],
dependencies=[Depends(USER_DEPENDENCY)],
)

@new_app.on_event("startup")
async def startup():
if db.is_initiated:
api_logger.critical("MySQL DB connected!")
else:
api_logger.critical("MySQL DB connection failed!")
if cache.redis is None:
raise ConnectionError("Redis is not connected yet!")
if cache.is_initiated and await cache.redis.ping():
await CacheManager.delete_user(f"testaccount@{config.host_main}")
api_logger.critical("Redis CACHE connected!")
else:
api_logger.critical("Redis CACHE connection failed!")
try:
import asyncio

import uvloop # type: ignore

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
api_logger.critical("uvloop installed!")
except ImportError:
api_logger.critical("uvloop not installed!")

if config.llama_cpp_completion_url:
# Start Llama CPP server monitoring
api_logger.critical("Llama CPP server monitoring started!")
Shared().thread = Thread(
target=monitor_llama_cpp_server,
args=(config, Shared().thread_terminate_signal),
)
Shared().thread.start()

@new_app.on_event("shutdown")
async def shutdown():
# await CacheManager.delete_user(f"testaccount@{HOST_MAIN}")
Shared().thread_terminate_signal.set()
Shared().process_terminate_signal.set()

process_manager = Shared()._process_manager
if process_manager is not None:
process_manager.shutdown()

process_pool_executor = Shared()._process_pool_executor
if process_pool_executor is not None:
process_pool_executor.shutdown(wait=False)

process = Shared()._process
if process is not None:
process.terminate()
process.join()

thread = Shared()._thread
if thread is not None:
thread.join()

await db.close()
await cache.close()
api_logger.critical("DB & CACHE connection closed!")

new_app.state.config = config
new_app.state.shared = Shared()
return new_app
14 changes: 14 additions & 0 deletions app/common/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,15 @@ class DescriptionTemplates:
),
input_variables=[],
)
USER_AI__ENGLISH: PromptTemplate = PromptTemplate(
template=(
"You are a good English teacher. Any sentence that {user} says that is surrounded"
' by double quotation marks ("") is asking how you interpret that sentence. Pleas'
"e analyze and explain that sentence in as much detail as possible. For the rest "
"of the sentences, please respond in a way that will help {user} learn English."
),
input_variables=["user"],
)


class ChatTurnTemplates:
Expand All @@ -82,6 +91,11 @@ class ChatTurnTemplates:
input_variables=["role", "content"],
template_format="f-string",
)
ROLE_CONTENT_4: PromptTemplate = PromptTemplate(
template="###{role}: {content}\n",
input_variables=["role", "content"],
template_format="f-string",
)


class SummarizationTemplates:
Expand Down
22 changes: 22 additions & 0 deletions app/common/lotties.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from typing import Self

from app.common.mixins import EnumMixin


class Lotties(EnumMixin):
CLICK = "lottie-click"
READ = "lottie-read"
SCROLL_DOWN = "lottie-scroll-down"
GO_BACK = "lottie-go-back"
SEARCH_WEB = "lottie-search-web"
SEARCH_DOC = "lottie-search-doc"
OK = "lottie-ok"
FAIL = "lottie-fail"
TRANSLATE = "lottie-translate"

def format(self, contents: str, end: bool = True) -> str:
return f"\n```{self.get_value(self)}\n{contents}" + ("\n```\n" if end else "")


if __name__ == "__main__":
print(Lotties.CLICK.format("hello"))
6 changes: 3 additions & 3 deletions app/middlewares/token_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from app.models.base_models import UserToken
from app.utils.auth.token import token_decode
from app.utils.date_utils import UTC
from app.utils.logger import api_logger
from app.utils.logger import ApiLogger
from app.utils.params_utils import hash_params


Expand Down Expand Up @@ -138,7 +138,7 @@ async def access_control(request: Request, call_next: RequestResponseEndpoint):
"code": error.code if not isinstance(error, HTTPException) else None,
},
)
api_logger.log_api(
ApiLogger.clog(
request=request,
response=response,
error=error,
Expand All @@ -150,7 +150,7 @@ async def access_control(request: Request, call_next: RequestResponseEndpoint):
else:
# Log error or service info
if url.startswith("/api/services"):
api_logger.log_api(
ApiLogger.clog(
request=request,
response=response,
cookies=request.cookies,
Expand Down
Loading

0 comments on commit ce768d4

Please sign in to comment.