llama-cpp API server added
c0sogi committed Jun 19, 2023
1 parent 5393c85 commit 9b42933
Showing 38 changed files with 1,413 additions and 394 deletions.
2 changes: 2 additions & 0 deletions .dockerignore
@@ -0,0 +1,2 @@
.env
llama_models/*
5 changes: 4 additions & 1 deletion .gitignore
@@ -4,5 +4,8 @@ PRIVATE_*
venv/
*.pyc
*.log
*.bin
llama_models/ggml/*
llama_models/gptq/*
!llama_models/ggml/llama_cpp_models_here.txt
!llama_models/gptq/gptq_models_here.txt
deprecated_*
92 changes: 89 additions & 3 deletions app/common/app_settings.py
@@ -1,3 +1,11 @@
from concurrent.futures import ProcessPoolExecutor
from concurrent.futures.process import BrokenProcessPool
from threading import Event
from threading import Thread
from time import sleep
from urllib import parse

import requests
from fastapi import Depends, FastAPI
from fastapi.staticfiles import StaticFiles
from starlette.middleware import Middleware
@@ -15,14 +23,65 @@
from app.dependencies import USER_DEPENDENCY, api_service_dependency
from app.middlewares.token_validator import access_control
from app.middlewares.trusted_hosts import TrustedHostMiddleware
from app.routers import auth, index, services, users, user_services, websocket
from app.routers import auth, index, services, user_services, users, websocket
from app.shared import Shared
from app.utils.chat.managers.cache import CacheManager
from app.utils.js_initializer import js_url_initializer
from app.utils.logger import api_logger
from app.viewmodels.admin import ApiKeyAdminView, UserAdminView


def check_health(url: str) -> bool:
try:
schema = parse.urlparse(url).scheme
netloc = parse.urlparse(url).netloc
if requests.get(f"{schema}://{netloc}/health").status_code != 200:
return False
return True
except Exception:
return False


def start_llama_cpp_server():
from app.start_llama_cpp_server import run

api_logger.critical("Starting Llama CPP server")
try:
Shared().process_pool_executor.submit(
run,
terminate_event=Shared().process_terminate_signal,
)
except BrokenProcessPool as e:
api_logger.exception(f"Broken Llama CPP server: {e}")
Shared().process_pool_executor.shutdown(wait=False)
Shared().process_pool_executor = ProcessPoolExecutor()
start_llama_cpp_server()
except Exception as e:
api_logger.exception(f"Failed to start Llama CPP server: {e}")


def shutdown_llama_cpp_server():
api_logger.critical("Shutting down Llama CPP server")
Shared().process_terminate_signal.set()


def monitor_llama_cpp_server(config: Config, terminate_signal: Event) -> None:
while not terminate_signal.is_set():
sleep(0.5)
if config.llama_cpp_api_url:
if not check_health(config.llama_cpp_api_url):
if config.is_llama_cpp_booting or terminate_signal.is_set():
continue
api_logger.error("Llama CPP server is not available")
config.llama_cpp_available = False
config.is_llama_cpp_booting = True
start_llama_cpp_server()
else:
config.is_llama_cpp_booting = False
config.llama_cpp_available = True
shutdown_llama_cpp_server()


def create_app(config: Config) -> FastAPI:
# Initialize app & db & js
new_app = FastAPI(
@@ -132,11 +191,38 @@ async def startup():
except ImportError:
api_logger.critical("uvloop not installed!")

if config.llama_cpp_api_url:
# Start Llama CPP server monitoring
api_logger.critical("Llama CPP server monitoring started!")
Shared().thread = Thread(
target=monitor_llama_cpp_server,
args=(config, Shared().thread_terminate_signal),
)
Shared().thread.start()

@new_app.on_event("shutdown")
async def shutdown():
# await CacheManager.delete_user(f"testaccount@{HOST_MAIN}")
Shared().process_manager.shutdown()
Shared().process_pool_executor.shutdown()
Shared().thread_terminate_signal.set()
Shared().process_terminate_signal.set()

process_manager = Shared()._process_manager
if process_manager is not None:
process_manager.shutdown()

process_pool_executor = Shared()._process_pool_executor
if process_pool_executor is not None:
process_pool_executor.shutdown(wait=False)

process = Shared()._process
if process is not None:
process.terminate()
process.join()

thread = Shared()._thread
if thread is not None:
thread.join()

await db.close()
await cache.close()
api_logger.critical("DB & CACHE connection closed!")
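In short, the hunk above adds three pieces: a check_health probe against the server's /health route, a monitor_llama_cpp_server loop running on a background thread started at app startup, and restart/shutdown hooks driven by shared terminate signals. The following is a minimal, self-contained sketch of that watchdog pattern; the URL, restart callback, and poll interval are placeholders rather than the project's real wiring.

import threading
import time
from typing import Callable
from urllib import parse

import requests


def check_health(url: str) -> bool:
    # True only if the host behind `url` answers /health with HTTP 200.
    try:
        parts = parse.urlparse(url)
        response = requests.get(f"{parts.scheme}://{parts.netloc}/health", timeout=2)
        return response.status_code == 200
    except requests.RequestException:
        return False


def watchdog(url: str, restart: Callable[[], None], stop: threading.Event, interval: float = 0.5) -> None:
    # Poll until `stop` is set; invoke `restart()` whenever the probe fails.
    while not stop.is_set():
        if not check_health(url):
            restart()
        time.sleep(interval)


if __name__ == "__main__":
    stop_event = threading.Event()
    worker = threading.Thread(
        target=watchdog,
        args=("http://localhost:8002/v1/completions", lambda: print("restarting worker"), stop_event),
        daemon=True,
    )
    worker.start()
    time.sleep(2)
    stop_event.set()
    worker.join()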
14 changes: 12 additions & 2 deletions app/common/config.py
@@ -6,9 +6,11 @@
from pathlib import Path
from re import Pattern, compile
from typing import Optional
from urllib import parse

import requests
from aiohttp import ClientTimeout
from dotenv import load_dotenv
from urllib import parse

load_dotenv()

@@ -141,8 +143,11 @@ class Config(metaclass=SingletonMetaClass):
shared_vectorestore_name: str = QDRANT_COLLECTION
trusted_hosts: list[str] = field(default_factory=lambda: ["*"])
allowed_sites: list[str] = field(default_factory=lambda: ["*"])
llama_cpp_api_url: Optional[str] = "http://localhost:8002/v1/completions"

def __post_init__(self):
self.llama_cpp_available: bool = self.llama_cpp_api_url is not None
self.is_llama_cpp_booting: bool = False
if not DOCKER_MODE:
self.port = 8001
self.mysql_host = "localhost"
@@ -248,7 +253,12 @@ class ChatConfig:
timeout: ClientTimeout = ClientTimeout(sock_connect=30.0, sock_read=20.0)
read_timeout: float = 30.0 # wait for this time before timeout
wait_for_reconnect: float = 3.0 # wait for this time before reconnecting
api_regex_pattern: Pattern = compile(r"data:\s*({.+?})\n\n")
api_regex_pattern_openai: Pattern = compile(
r"data:\s*({.+?})\n\n"
) # regex pattern to extract json from openai api response
api_regex_pattern_llama_cpp: Pattern = compile(
r"data:\s*({.+?})\r\n\r\n"
) # regex pattern to extract json from llama cpp api response
extra_token_margin: int = (
512 # number of tokens to remove when tokens exceed token limit
)
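The two patterns above differ only in the frame delimiter: the OpenAI streaming API terminates each data: event with \n\n, while the llama-cpp server ends frames with \r\n\r\n. A small sketch of how either pattern lifts the JSON payload out of a raw chunk; the sample payloads below are illustrative, not captured responses.

import json
import re

api_regex_pattern_openai = re.compile(r"data:\s*({.+?})\n\n")
api_regex_pattern_llama_cpp = re.compile(r"data:\s*({.+?})\r\n\r\n")

# Illustrative stream chunks, one per framing style.
openai_chunk = 'data: {"choices": [{"delta": {"content": "Hello"}}]}\n\n'
llama_cpp_chunk = 'data: {"choices": [{"text": "Hello"}]}\r\n\r\n'

for pattern, chunk in (
    (api_regex_pattern_openai, openai_chunk),
    (api_regex_pattern_llama_cpp, llama_cpp_chunk),
):
    match = pattern.search(chunk)
    if match is not None:
        print(json.loads(match.group(1)))  # the extracted JSON object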
4 changes: 2 additions & 2 deletions app/common/constants.py
@@ -11,7 +11,7 @@ class QueryTemplates:
"\n---\n"
"{context}"
"\n---\n"
"Answer the question in detail: {question}\n"
"Answer the question in as much detail as possible: {question}\n"
),
input_variables=["context", "question"],
template_format="f-string",
@@ -23,7 +23,7 @@
"{context}"
"\n---\n"
"Given the context information and not prior knowledge, "
"answer the question in detail: {question}\n"
"answer the question in as much detail as possible:: {question}\n"
),
input_variables=["context", "question"],
template_format="f-string",
Binary file added app/contents/llama_api.png
27 changes: 11 additions & 16 deletions app/database/schemas/auth.py
@@ -11,23 +11,12 @@
Mapped,
mapped_column,
)

from app.viewmodels.status import ApiKeyStatus, UserStatus
from .. import Base
from . import Mixin


class UserStatus(str, enum.Enum):
admin = "admin"
active = "active"
deleted = "deleted"
blocked = "blocked"


class ApiKeyStatus(str, enum.Enum):
active = "active"
stopped = "stopped"
deleted = "deleted"


class Users(Base, Mixin):
__tablename__ = "users"
status: Mapped[str] = mapped_column(Enum(UserStatus), default=UserStatus.active)
@@ -37,7 +26,9 @@ class Users(Base, Mixin):
phone_number: Mapped[str | None] = mapped_column(String(length=20))
profile_img: Mapped[str | None] = mapped_column(String(length=100))
marketing_agree: Mapped[bool] = mapped_column(Boolean, default=True)
api_keys: Mapped["ApiKeys"] = relationship(back_populates="users", cascade="all, delete-orphan", lazy=True)
api_keys: Mapped["ApiKeys"] = relationship(
back_populates="users", cascade="all, delete-orphan", lazy=True
)
# chat_rooms: Mapped["ChatRooms"] = relationship(back_populates="users", cascade="all, delete-orphan", lazy=True)
# chat_messages: Mapped["ChatMessages"] = relationship(
# back_populates="users", cascade="all, delete-orphan", lazy=True
@@ -56,12 +47,16 @@ class ApiKeys(Base, Mixin):
is_whitelisted: Mapped[bool] = mapped_column(default=False)
user_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"))
users: Mapped["Users"] = relationship(back_populates="api_keys")
whitelists: Mapped["ApiWhiteLists"] = relationship(backref="api_keys", cascade="all, delete-orphan")
whitelists: Mapped["ApiWhiteLists"] = relationship(
backref="api_keys", cascade="all, delete-orphan"
)


class ApiWhiteLists(Base, Mixin):
__tablename__ = "api_whitelists"
api_key_id: Mapped[int] = mapped_column(Integer, ForeignKey("api_keys.id", ondelete="CASCADE"))
api_key_id: Mapped[int] = mapped_column(
Integer, ForeignKey("api_keys.id", ondelete="CASCADE")
)
ip_address: Mapped[str] = mapped_column(String(length=64))


8 changes: 4 additions & 4 deletions app/models/base_models.py
@@ -6,8 +6,8 @@
from pydantic import Field
from pydantic.main import BaseModel

from app.database.schemas.auth import UserStatus
from app.utils.date_utils import UTC
from app.viewmodels.status import UserStatus

JSON_TYPES = Union[int, float, str, bool, dict, list, None]

@@ -135,7 +135,7 @@ class Config:
orm_mode = True


class OpenAIChatMessage(BaseModel):
class APIChatMessage(BaseModel):
role: str
content: str

@@ -146,10 +146,10 @@ class Config:
class MessageHistory(BaseModel):
role: str
content: str
tokens: int
actual_role: str
tokens: int = 0
timestamp: int = Field(default_factory=UTC.timestamp)
uuid: str = Field(default_factory=lambda: uuid4().hex)
actual_role: Optional[str] = None
model_name: Optional[str] = None
summarized: Optional[str] = None
summarized_tokens: Optional[int] = None
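The reordered MessageHistory fields above turn tokens and actual_role into fields with defaults, so callers only need to supply role and content. A trimmed-down sketch of the resulting behaviour, written against the pydantic v1 API the repo uses and with a stand-in for the project's UTC.timestamp helper:

import time
from typing import Optional
from uuid import uuid4

from pydantic import BaseModel, Field


class MessageHistory(BaseModel):
    role: str
    content: str
    tokens: int = 0
    timestamp: int = Field(default_factory=lambda: int(time.time()))  # stand-in for UTC.timestamp
    uuid: str = Field(default_factory=lambda: uuid4().hex)
    actual_role: Optional[str] = None
    model_name: Optional[str] = None


if __name__ == "__main__":
    # Only role and content are required now; everything else falls back to a default.
    print(MessageHistory(role="user", content="hi").dict())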
33 changes: 22 additions & 11 deletions app/models/llm_tokenizers.py
@@ -1,13 +1,12 @@
from abc import ABC, abstractmethod
from typing import Type
from typing import TYPE_CHECKING, Type

from tiktoken import Encoding, encoding_for_model
from transformers.models.llama import LlamaTokenizer as _LlamaTokenizer

from app.utils.chat.llama_cpp import LlamaTokenizerAdapter

from app.utils.logger import api_logger
from app.shared import Shared

if TYPE_CHECKING:
from tiktoken import Encoding
from app.utils.chat.text_generations._llama_cpp import LlamaTokenizerAdapter


class BaseTokenizer(ABC):
@@ -39,11 +38,16 @@ def split_text_on_tokens(
chunk_ids = input_ids[start_idx:cur_idx]
return splits

def get_chunk_of(self, text: str, tokens: int) -> str:
"""Split incoming text and return chunks."""
input_ids = self.encode(text)
return self.decode(input_ids[: min(tokens, len(input_ids))])


class OpenAITokenizer(BaseTokenizer):
def __init__(self, model_name: str):
self.model_name = model_name
self._tokenizer: Encoding | None = None
self._tokenizer: "Encoding" | None = None

def encode(self, message: str, /) -> list[int]:
return self.tokenizer.encode(message)
@@ -59,7 +63,9 @@ def vocab_size(self) -> int:
return self.tokenizer.n_vocab

@property
def tokenizer(self) -> Encoding:
def tokenizer(self) -> "Encoding":
from tiktoken import encoding_for_model

if self._tokenizer is None:
print("Loading tokenizer: ", self.model_name)
self._tokenizer = encoding_for_model(self.model_name)
@@ -69,7 +75,7 @@ def tokenizer(self) -> Encoding:
class LlamaTokenizer(BaseTokenizer):
def __init__(self, model_name: str):
self.model_name = model_name
self._tokenizer: Encoding | None = None
self._tokenizer: "Encoding" | None = None

def encode(self, message: str, /) -> list[int]:
return self.tokenizer.encode(message)
@@ -85,7 +91,9 @@ def vocab_size(self) -> int:
return self.tokenizer.n_vocab

@property
def tokenizer(self) -> Encoding:
def tokenizer(self) -> "Encoding":
from transformers.models.llama import LlamaTokenizer as _LlamaTokenizer

if self._tokenizer is None:
split_str = self.model_name.split("/")

@@ -118,6 +126,7 @@ def __init__(self, llama_cpp_model_name: str):
def encode(self, message: str, /) -> list[int]:
from app.models.llms import LLMModels
from app.models.llms import LlamaCppModel
from app.shared import Shared

llama_cpp_model = LLMModels.find_model_by_name(self.llama_cpp_model_name)
assert isinstance(llama_cpp_model, LlamaCppModel), type(llama_cpp_model)
Expand All @@ -135,5 +144,7 @@ def tokens_of(self, message: str) -> int:
return len(self.encode(message))

@property
def tokenizer(self) -> Type[LlamaTokenizerAdapter]:
def tokenizer(self) -> Type["LlamaTokenizerAdapter"]:
from app.utils.chat.text_generations._llama_cpp import LlamaTokenizerAdapter

return LlamaTokenizerAdapter
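The tokenizer changes above push the heavy imports (tiktoken, transformers, the llama-cpp adapter) out of module import time and behind TYPE_CHECKING guards or first property access, and add a get_chunk_of helper for token-bounded truncation. The class below is a simplified stand-in showing the same deferred-import pattern, not the project's actual code.

from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:  # visible to type checkers only, never imported at runtime
    from tiktoken import Encoding


class LazyOpenAITokenizer:
    def __init__(self, model_name: str) -> None:
        self.model_name = model_name
        self._tokenizer: Optional["Encoding"] = None

    @property
    def tokenizer(self) -> "Encoding":
        if self._tokenizer is None:  # first access pays the import and download cost
            from tiktoken import encoding_for_model

            self._tokenizer = encoding_for_model(self.model_name)
        return self._tokenizer

    def get_chunk_of(self, text: str, tokens: int) -> str:
        # Return the longest prefix of `text` that fits within `tokens` tokens.
        input_ids = self.tokenizer.encode(text)
        return self.tokenizer.decode(input_ids[: min(tokens, len(input_ids))])


if __name__ == "__main__":
    tk = LazyOpenAITokenizer("gpt-3.5-turbo")
    print(tk.get_chunk_of("The quick brown fox jumps over the lazy dog.", tokens=5))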