Modified process pool
c0sogi committed Jul 24, 2023
1 parent 59cb8d8 commit 91e06c1
Showing 16 changed files with 1,060 additions and 277 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -1,4 +1,5 @@
-models/*
+models/ggml/*
+models/gptq/*
 !models/ggml/llama_cpp_models_here.txt
 !models/gptq/exllama_models_here.txt
 *.log
77 changes: 77 additions & 0 deletions llama_api/mixins/prompt_utils.py
@@ -0,0 +1,77 @@
from ..schemas.api import APIChatMessage, TextGenerationSettings


class PromptUtilsMixin:
    user_role: str = "user"
    system_role: str = "system"
    user_input_role: str = "User"
    system_input_role: str = "System"
    ai_fallback_input_role: str = "Assistant"

    @staticmethod
    def get_stop_strings(*roles: str) -> list[str]:
        """A helper method to generate stop strings for a given set of roles.
        Stop strings are required to stop the text completion API from
        generating text that does not belong to the current chat turn,
        e.g. the common stop string "### USER:" prevents the AI from
        generating the user's message itself."""

        prompt_stop = set()
        for role in roles:
            avoids = (
                f"{role}:",
                f"### {role}:",
                f"###{role}:",
            )
            prompt_stop.update(
                avoids,
                map(str.capitalize, avoids),
                map(str.upper, avoids),
                map(str.lower, avoids),
            )
        return list(prompt_stop)

    @classmethod
    def convert_messages_into_prompt(
        cls, messages: list[APIChatMessage], settings: TextGenerationSettings
    ) -> str:
        """A helper method to convert a list of messages into one text prompt."""

        ai_input_role: str = cls.ai_fallback_input_role
        chat_history: str = ""
        for message in messages:
            if message.role.lower() == cls.user_role:
                input_role = cls.user_input_role
            elif message.role.lower() == cls.system_role:
                input_role = cls.system_input_role
            else:
                input_role = ai_input_role = message.role
            chat_history += f"### {input_role}:{message.content}"

        prompt_stop: list[str] = cls.get_stop_strings(
            cls.user_input_role, cls.system_input_role, ai_input_role
        )
        if isinstance(settings.stop, str):
            settings.stop = prompt_stop + [settings.stop]
        elif isinstance(settings.stop, list):
            settings.stop = prompt_stop + settings.stop
        else:
            settings.stop = prompt_stop
        return chat_history + f"### {ai_input_role}:"

    @staticmethod
    def is_possible_to_generate_stops(
        decoded_text: str, stops: list[str]
    ) -> bool:
        """A helper method to check if the decoded text contains any of
        the stop strings, or ends with a prefix of one."""

        for stop in stops:
            if stop in decoded_text or any(
                [
                    decoded_text.endswith(stop[: i + 1])
                    for i in range(len(stop))
                ]
            ):
                return True
        return False
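
For orientation, here is a rough usage sketch of the new mixin (not part of the commit). The import paths follow the file locations in this diff; the keyword arguments passed to APIChatMessage and TextGenerationSettings are assumptions about those schema classes.

# Hypothetical usage sketch -- not part of the commit.
from llama_api.mixins.prompt_utils import PromptUtilsMixin
from llama_api.schemas.api import APIChatMessage, TextGenerationSettings

messages = [
    APIChatMessage(role="system", content="You are a helpful assistant."),
    APIChatMessage(role="user", content="Hello!"),
]
settings = TextGenerationSettings(stop=None)  # assumed constructor

# Builds "### System:...### User:...### Assistant:" and, as a side effect,
# injects stop strings such as "### User:" / "### USER:" into settings.stop.
prompt = PromptUtilsMixin.convert_messages_into_prompt(messages, settings)

# While streaming, a backend can hold back a chunk that might be the start
# of a stop string:
PromptUtilsMixin.is_possible_to_generate_stops("### Us", settings.stop)  # True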
30 changes: 30 additions & 0 deletions llama_api/mixins/waiter.py
@@ -0,0 +1,30 @@
from threading import Event
from typing import Optional


class WaiterMixin:
    _is_available: Optional[Event] = None

    @property
    def is_available(self) -> bool:
        """Check if the model is available."""
        if self._is_available is None:
            self._is_available = Event()
            self._is_available.set()
        return self._is_available.is_set()

    def wait_until_available(self) -> None:
        """Wait until the model is available."""
        if self._is_available is None:
            self._is_available = Event()
            self._is_available.set()
        self._is_available.wait()

    def set_availability(self, availability: bool) -> None:
        """Set the model availability."""
        if self._is_available is None:
            self._is_available = Event()
        if availability:
            self._is_available.set()
        else:
            self._is_available.clear()
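
To illustrate the intended pattern, here is a hypothetical worker that gates generation on the availability event (not part of the commit; the Worker class and the sleep are stand-ins). Since wait_until_available() followed by set_availability(False) is not atomic, this is a best-effort gate rather than a strict lock.

# Hypothetical sketch -- not part of the commit.
import time
from threading import Thread

from llama_api.mixins.waiter import WaiterMixin


class Worker(WaiterMixin):
    def generate(self, prompt: str) -> str:
        self.wait_until_available()   # block while another request holds the model
        self.set_availability(False)  # mark the model as busy
        try:
            time.sleep(0.1)           # stand-in for the actual generation work
            return prompt.upper()
        finally:
            self.set_availability(True)  # release the model for the next caller


worker = Worker()
threads = [Thread(target=worker.generate, args=(f"request {i}",)) for i in range(3)]
for t in threads:
    t.start()
for t in threads:
    t.join()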
80 changes: 3 additions & 77 deletions llama_api/modules/base.py
@@ -2,6 +2,8 @@
 from dataclasses import dataclass
 from typing import Any, Iterator, TypeVar

+from ..mixins.prompt_utils import PromptUtilsMixin
+from ..mixins.waiter import WaiterMixin
 from ..schemas.api import (
     APIChatMessage,
     ChatCompletion,
@@ -22,17 +24,9 @@ class BaseLLMModel:
     max_total_tokens: int = 2048


-class BaseCompletionGenerator(ABC):
+class BaseCompletionGenerator(ABC, PromptUtilsMixin, WaiterMixin):
     """Base class for all completion generators."""

-    user_role: str = "user"
-    system_role: str = "system"
-
-    user_input_role: str = "User"
-    system_input_role: str = "System"
-
-    ai_fallback_input_role: str = "Assistant"
-
     @abstractmethod
     def __del__(self):
         """Clean up resources."""
@@ -76,74 +70,6 @@ def generate_chat_completion_with_streaming(
         yielding chunks of text as they are generated."""
         ...

-    @staticmethod
-    def get_stop_strings(*roles: str) -> list[str]:
-        """A helper method to generate stop strings for a given set of roles.
-        Stop strings are required to stop text completion API from generating
-        text that does not belong to the current chat turn.
-        e.g. The common stop string is "### USER:",
-        which can prevent ai from generating user's message itself."""
-
-        prompt_stop = set()
-        for role in roles:
-            avoids = (
-                f"{role}:",
-                f"### {role}:",
-                f"###{role}:",
-            )
-            prompt_stop.update(
-                avoids,
-                map(str.capitalize, avoids),
-                map(str.upper, avoids),
-                map(str.lower, avoids),
-            )
-        return list(prompt_stop)
-
-    @classmethod
-    def convert_messages_into_prompt(
-        cls, messages: list[APIChatMessage], settings: TextGenerationSettings
-    ) -> str:
-        """A helper method to convert list of messages into one text prompt."""
-
-        ai_input_role: str = cls.ai_fallback_input_role
-        chat_history: str = ""
-        for message in messages:
-            if message.role.lower() == cls.user_role:
-                input_role = cls.user_input_role
-            elif message.role.lower() == cls.system_role:
-                input_role = cls.system_input_role
-            else:
-                input_role = ai_input_role = message.role
-            chat_history += f"### {input_role}:{message.content}"
-
-        prompt_stop: list[str] = cls.get_stop_strings(
-            cls.user_input_role, cls.system_input_role, ai_input_role
-        )
-        if isinstance(settings.stop, str):
-            settings.stop = prompt_stop + [settings.stop]
-        elif isinstance(settings.stop, list):
-            settings.stop = prompt_stop + settings.stop
-        else:
-            settings.stop = prompt_stop
-        return chat_history + f"### {ai_input_role}:"
-
-    @staticmethod
-    def is_possible_to_generate_stops(
-        decoded_text: str, stops: list[str]
-    ) -> bool:
-        """A helper method to check if
-        the decoded text contains any of the stop tokens."""
-
-        for stop in stops:
-            if stop in decoded_text or any(
-                [
-                    decoded_text.endswith(stop[: i + 1])
-                    for i in range(len(stop))
-                ]
-            ):
-                return True
-        return False
-
     @property
     @abstractmethod
     def llm_model(self) -> "BaseLLMModel":
11 changes: 10 additions & 1 deletion llama_api/server/app_settings.py
@@ -1,6 +1,7 @@
 import platform
 import subprocess
 from contextlib import asynccontextmanager
+from os import environ
 from typing import Optional


@@ -67,10 +68,12 @@ def initialize_before_launch(install_packages: bool = False):
 @asynccontextmanager
 async def lifespan(app):
     from ..utils.logger import ApiLogger
+    from ..utils.concurrency import pool

     ApiLogger.ccritical("🦙 LLaMA API server is running")
     yield
     ApiLogger.ccritical("🦙 Shutting down LLaMA API server...")
+    pool().kill()


 def create_app_llama_cpp():
@@ -96,10 +99,16 @@ async def health():
     return new_app


-def run(port: int) -> None:
+def run(
+    port: int,
+    max_workers: int = 1,
+) -> None:
     initialize_before_launch(install_packages=True)

     from uvicorn import Config, Server

+    environ["MAX_WORKERS"] = str(max_workers)
+
     Server(
         config=Config(
             create_app_llama_cpp(),
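
For completeness, a minimal launch sketch using the new signature (hypothetical, not part of the commit; the port and worker count are example values, and it assumes the process-pool helper in llama_api.utils.concurrency sizes itself from the MAX_WORKERS environment variable that run() exports):

# Hypothetical launcher -- not part of the commit.
from llama_api.server.app_settings import run

if __name__ == "__main__":
    # run() sets MAX_WORKERS in the environment before starting uvicorn,
    # and the lifespan hook above kills the process pool on shutdown.
    run(port=8000, max_workers=2)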