refactored cli args
c0sogi committed Sep 7, 2023
1 parent 02993c8 commit 678a13c
Showing 10 changed files with 236 additions and 79 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -5,6 +5,7 @@ models/gptq/*
repositories/
*.log
*.pyc
*.csv
/**/_*
.venv/
.vscode/
42 changes: 31 additions & 11 deletions llama_api/server/app_settings.py
@@ -1,23 +1,23 @@
import sys
from contextlib import asynccontextmanager
from os import environ, getpid
from pathlib import Path
from random import randint
import sys
from threading import Timer
from typing import Literal, Optional

from ..shared.config import AppSettingsCliArgs, MainCliArgs, Config

from ..shared.config import AppSettingsCliArgs, Config, MainCliArgs
from ..utils.dependency import (
get_installed_packages,
get_outdated_packages,
get_poetry_executable,
git_clone,
git_pull,
run_command,
install_all_dependencies,
install_package,
install_pytorch,
install_tensorflow,
run_command,
)
from ..utils.llama_cpp import build_shared_lib
from ..utils.logger import ApiLogger
@@ -80,10 +80,8 @@ def initialize_before_launch() -> None:
skip_tensorflow_install = args.skip_tf_install.value or False
skip_compile = args.skip_compile.value or False
no_cache_dir = args.no_cache_dir.value or False
print(
"Starting Application with CLI args:"
+ str(environ["LLAMA_API_ARGS"])
)

print(f"\033[37;46;1m{environ['LLAMA_API_ARGS']}\033[0m")

# PIP arguments
pip_args = [] # type: list[str]
@@ -124,12 +122,23 @@ def initialize_before_launch() -> None:
install_tensorflow(args=pip_args)

# Install all dependencies of our project and other repositories
project_paths = [Path(".")] + list(Path("repositories").glob("*"))
install_all_dependencies(project_paths=project_paths, args=pip_args)
install_all_dependencies(
project_paths=[Path(".")] + list(Path("repositories").glob("*")),
args=pip_args,
)

# Get current packages installed
logger.info(f"📦 Installed packages: {get_installed_packages()}")
else:
if upgrade:
outdated_packages = get_outdated_packages()
if outdated_packages:
logger.warning(
"📦 Upgrading outdated packages: " f"{outdated_packages}"
)
install_package(" ".join(outdated_packages), args=pip_args)
else:
logger.info("📦 All packages are up-to-date!")
logger.warning(
"🏃‍♂️ Skipping package installation... "
"If any packages are missing, "
@@ -142,12 +151,23 @@ def initialize_before_launch() -> None:
@asynccontextmanager
async def lifespan(app):
from ..utils.logger import ApiLogger
from ..utils.model_definition_finder import ModelDefinitions

model_mappings, oai_mappings = ModelDefinitions.get_model_mappings()
for oai_name, llama_name in oai_mappings.items():
if llama_name in model_mappings:
model_mappings[oai_name] = model_mappings[llama_name]
print(
"\n".join(
f"\033[34;47;1m{name}\033[0m\n{llm_model.repr()}"
for name, llm_model in model_mappings.items()
)
)
ApiLogger.cinfo("🦙 LLaMA API server is running")
try:
yield
finally:
from ..utils.concurrency import _pool, _manager
from ..utils.concurrency import _manager, _pool

if _manager is not None:
_manager.shutdown()
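Beyond the import reordering, app_settings.py gains two behavioral changes: an --upgrade path that installs whatever get_outdated_packages() reports, and a lifespan() hook that aliases OpenAI-style model names onto the locally defined models before printing them. The aliasing loop can be illustrated with toy dictionaries; the dict contents below are made up, only the loop mirrors the committed code.

# Toy illustration of the aliasing loop added to lifespan().
model_mappings = {"my-llama-13b": object()}        # llama name -> model definition
oai_mappings = {"gpt-3.5-turbo": "my-llama-13b"}   # OpenAI alias -> llama name

for oai_name, llama_name in oai_mappings.items():
    if llama_name in model_mappings:
        # The alias now resolves to the same model definition object.
        model_mappings[oai_name] = model_mappings[llama_name]

assert model_mappings["gpt-3.5-turbo"] is model_mappings["my-llama-13b"]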
16 changes: 11 additions & 5 deletions llama_api/server/routers/v1.py
@@ -66,7 +66,9 @@
ChatCompletionContext = Tuple[
Request, CreateChatCompletionRequest, int, Queue, Event
]
CompletionContext = Tuple[Request, CreateCompletionRequest, int, Queue, Event]
CompletionContext = Tuple[
Request, CreateCompletionRequest, int, Queue, Event
]
EmbeddingContext = Tuple[Request, CreateEmbeddingRequest, int, Queue, Event]
T = TypeVar("T")

@@ -91,7 +93,7 @@ class WixHandler:
processing a request. This is used to prevent multiple requests from
creating multiple completion generators at the same time."""

wix_metas: Tuple[WixMetadata] = tuple(
wix_metas: Tuple[WixMetadata, ...] = tuple(
WixMetadata(wix) for wix in range(MAX_WORKERS)
)

@@ -113,7 +115,9 @@ def get_wix_meta(cls, request_key: Optional[str] = None) -> WixMetadata:
return cls.wix_metas[choice(candidates)]

@staticmethod
def _get_worker_rank(meta: WixMetadata, request_key: Optional[str]) -> int:
def _get_worker_rank(
meta: WixMetadata, request_key: Optional[str]
) -> int:
"""Get the entry rank for the worker index (wix) metadata.
Lower rank means higher priority of the worker to process the request.
If the rank is -2, then the worker is processing the same model
@@ -138,7 +142,9 @@ def validate_item_type(item: Any, type: Type[T]) -> T:
raise item
elif not isinstance(item, type):
# The producer task has returned an invalid response
raise TypeError(f"Expected type {type}, but got {type(item)} instead")
raise TypeError(
f"Expected type {type}, but got {type(item)} instead"
)
return item


@@ -348,6 +354,6 @@ async def get_models() -> ModelList:
owned_by="me",
permissions=[],
)
for model_name in ModelDefinitions.get_llm_model_names()
for model_name in ModelDefinitions.get_model_names()
],
)
34 changes: 34 additions & 0 deletions llama_api/shared/config.py
@@ -17,6 +17,10 @@ class GitCloneArgs(TypedDict):


class AppSettingsCliArgs(CliArgHelper):
__description__ = (
"Settings for the server, and installation of dependencies"
)

install_pkgs: CliArg[bool] = CliArg(
type=bool,
action="store_true",
@@ -63,6 +67,9 @@ class AppSettingsCliArgs(CliArgHelper):


class MainCliArgs(AppSettingsCliArgs):
__description__ = (
"Main CLI arguments for the server, including app settings"
)
port: CliArg[int] = CliArg(
type=int,
short_option="p",
@@ -117,6 +124,7 @@ class MainCliArgs(AppSettingsCliArgs):


class ModelDownloaderCliArgs(CliArgHelper):
__description__ = "Download models from HuggingFace"
model: CliArgList[str] = CliArgList(
type=str,
n_args="+",
@@ -159,6 +167,32 @@ class ModelDownloaderCliArgs(CliArgHelper):
)


class LogParserCliArgs(CliArgHelper):
__description__ = "Process chat and debug logs."

min_output_length: CliArg[int] = CliArg(
type=int, default=30, help="Minimum length for the output."
)
chat_log_file_path: CliArg[str] = CliArg(
type=str,
default="logs/chat.log",
help="Path to the chat log file.",
)
debug_log_file_path: CliArg[str] = CliArg(
type=str,
default="logs/debug.log",
help="Path to the debug log file.",
)
ignore_messages_less_than: CliArg[int] = CliArg(
type=int, default=2, help="Ignore messages shorter than this length."
)
output_path: CliArg[str] = CliArg(
type=str,
default="./logs/chat.csv",
help="Path to save the extracted chats as CSV.",
)


class Config:
"""Configuration for the project"""

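The new LogParserCliArgs class (together with the *.csv entry added to .gitignore) points to a log-parsing script that extracts chats from logs/chat.log into a CSV file. A hedged sketch of how such a script might consume the helper, assuming CliArgHelper.load() parses sys.argv and that each CliArg exposes its parsed result on .value, as app_settings.py does for MainCliArgs:

# Hypothetical entry point; the script itself is not part of this commit.
from llama_api.shared.config import LogParserCliArgs

if __name__ == "__main__":
    LogParserCliArgs.load()
    chat_log = LogParserCliArgs.chat_log_file_path.value or "logs/chat.log"
    output_csv = LogParserCliArgs.output_path.value or "./logs/chat.csv"
    min_len = LogParserCliArgs.min_output_length.value or 30
    print(f"Parsing {chat_log} -> {output_csv} (min output length {min_len})")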
4 changes: 3 additions & 1 deletion llama_api/utils/cli.py
@@ -54,6 +54,8 @@ class CliArgHelper:
"""Helper class for loading CLI arguments from environment variables
or a namespace of CLI arguments"""

__description__: Optional[str] = None

@classmethod
def load(
cls,
@@ -157,7 +159,7 @@ def iterate_over_cli_args(cls) -> Iterable[Tuple[str, CliArg]]:
@classmethod
def get_parser(cls) -> argparse.ArgumentParser:
"""Return an argument parser with all CLI arguments"""
arg_parser = argparse.ArgumentParser()
arg_parser = argparse.ArgumentParser(description=cls.__description__)
for cli_key, cli_arg in cls.iterate_over_cli_args():
args = [] # type: List[str]
if cli_arg.is_positional:
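Threading __description__ into argparse.ArgumentParser means each helper class's description now appears in its --help output. A minimal standalone sketch of the mechanism, with the CliArg plumbing omitted:

import argparse
from typing import Optional

class CliArgHelperSketch:
    __description__: Optional[str] = None

    @classmethod
    def get_parser(cls) -> argparse.ArgumentParser:
        # Subclasses that set __description__ get it in their usage text for free.
        return argparse.ArgumentParser(description=cls.__description__)

class MainCliArgsSketch(CliArgHelperSketch):
    __description__ = "Main CLI arguments for the server, including app settings"

MainCliArgsSketch.get_parser().print_help()  # the description line shows in --help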
37 changes: 33 additions & 4 deletions llama_api/utils/dependency.py
@@ -46,7 +46,9 @@ def run_command(
return False
else:
if verbose:
logger.info(f"{success_emoji} Successfully {name} {action}ed.")
logger.info(
f"{success_emoji} Successfully {name} {action}ed."
)
return True
except Exception as e:
if verbose:
@@ -166,7 +168,11 @@ def check_if_torch_version_available(

# Check if the CUDA version of torch is available
for line in urlopen(source).read().splitlines():
if package_ver in line and python_ver in line and platform in line:
if (
package_ver in line
and python_ver in line
and platform in line
):
return True
return False
except Exception:
@@ -188,7 +194,9 @@ def parse_requirements(
A list of parsed requirements.
"""
# Define the regular expression pattern
pattern = compile(r"([a-zA-Z0-9_\-\+]+)(==|>=|<=|~=|>|<|!=|===)([0-9\.]+)")
pattern = compile(
r"([a-zA-Z0-9_\-\+]+)(==|>=|<=|~=|>|<|!=|===)([0-9\.]+)"
)

# Use finditer to get all matches in the string
return [
@@ -388,6 +396,25 @@ def install_all_dependencies(
return result


def get_outdated_packages() -> List[str]:
return [
line.split("==")[0]
for line in run(
[
sys.executable,
"-m",
"pip",
"list",
"--outdated",
"--format=freeze",
],
capture_output=True,
text=True,
).stdout.splitlines()
if not line.startswith("-e")
]


def remove_all_dependencies():
"""Remove all dependencies.
To be used when cleaning up the environment."""
@@ -404,7 +431,9 @@ def remove_all_dependencies():

# Step 2: Uninstall all packages listed in the temp file
with open(temp_path, "r") as temp_file:
packages = [line.strip() for line in temp_file if "-e" not in line]
packages = [
line.strip() for line in temp_file if "-e" not in line
]

for package in packages:
# The "--yes" option automatically confirms the uninstallation
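The new get_outdated_packages() relies on pip's freeze output format, which emits one name==version line per outdated distribution; editable installs are prefixed with -e, which is why those lines are skipped. A small parsing sketch on canned output (the sample lines are made up, but follow the --format=freeze convention):

sample_stdout = """\
fastapi==0.95.0
uvicorn==0.21.1
-e git+https://example.com/some/repo.git#egg=local_pkg
"""

outdated = [
    line.split("==")[0]
    for line in sample_stdout.splitlines()
    if not line.startswith("-e")
]
print(outdated)  # ['fastapi', 'uvicorn']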
2 changes: 1 addition & 1 deletion llama_api/utils/errors.py
@@ -54,7 +54,7 @@ def context_length_exceeded(
request: Union[
"CreateCompletionRequest", "CreateChatCompletionRequest"
],
match, # type: Match[str] # type: ignore
match, # type: Match[str]
) -> Tuple[int, ErrorResponse]:
"""Formatter for context length exceeded error"""
