exllama: supports loading sharded model
c0sogi committed Sep 9, 2023
1 parent 2b299fc commit 461eaf1
Showing 4 changed files with 193 additions and 190 deletions.
91 changes: 57 additions & 34 deletions llama_api/modules/exllama.py
@@ -217,29 +217,34 @@ def _generate_text(
for _ in range(settings.max_tokens):
# If the generator was interrupted, stop the generation
if self.check_interruption(completion_status):
break
return

# Predict next token id
token_id = (
_gen_single_token_with_cfg(
generator=generator,
mask=cfg_mask,
cfg_alpha=settings.guidance_scale,
)
if cfg_mask is not None
else _gen_single_token_without_cfg(
generator=generator,
input_ids=generator.sequence[0][initial_len:],
logit_processors=logit_processors,
)
) # type: int

try:
token_id = (
_gen_single_token_with_cfg(
generator=generator,
mask=cfg_mask,
cfg_alpha=settings.guidance_scale,
)
if cfg_mask is not None
else _gen_single_token_without_cfg(
generator=generator,
input_ids=generator.sequence[0][initial_len:],
logit_processors=logit_processors,
)
) # type: int
except RuntimeError as e:
if "exceeds dimension size" in str(e):
logger.warning(f"Ignoring ExLlama RuntimeError: {e}")
return
raise e
# Check if the token is a stop token
if (
self.check_interruption(completion_status)
or token_id == eos_token_id
):
break
return

# Update the completion status
completion_status.generated_tokens += 1
@@ -265,7 +270,7 @@ def _generate_text(
completion_status.generated_text += text_to_yield
yield text_to_yield
elif stop_status is True: # Contains any of the stop tokens
break # Stop generating
return # Stop generating
else: # Contains any piece of the stop tokens
text_buffer = text_to_yield # Save the buffer
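
Note on the `break` -> `return` changes in the two hunks above: `_generate_text` is a generator (it yields text chunks), so `return` ends iteration immediately, whereas `break` would fall through to whatever follows the loop (not shown in these hunks) before the generator finishes. A minimal standalone sketch of that Python semantics difference; `gen_with_break` and `gen_with_return` are made-up names for illustration, not code from this repository:

def gen_with_break(tokens):
    buffer = "partial"
    for t in tokens:
        if t == "<stop>":
            break          # falls through to the code after the loop
        yield t
    yield buffer           # still runs after `break`

def gen_with_return(tokens):
    buffer = "partial"
    for t in tokens:
        if t == "<stop>":
            return         # ends the generator here; the buffer is never yielded
        yield t
    yield buffer

print(list(gen_with_break(["a", "<stop>", "b"])))   # ['a', 'partial']
print(list(gen_with_return(["a", "<stop>", "b"])))  # ['a']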

@@ -275,26 +280,44 @@ def _make_config(
) -> ExLlamaConfig:
"""Create a config object for the ExLlama model."""

# Find the model checkpoint
model_file_found: List[Path] = []
for ext in (".safetensors", ".pt", ".bin"):
model_file_found.extend(model_folder_path.glob(f"*{ext}"))
if model_file_found:
if len(model_file_found) > 1:
logger.warning(
f"More than one {ext} model has been found. "
"The last one will be selected. It could be wrong."
)
# Find the model checkpoint file and remove numbers from file names
remove_numbers_pattern = compile(r"\d+")
grouped_by_base_name = {} # type: dict[str, list[Path]]
for model_file in (
list(model_folder_path.glob("*.safetensors"))
or list(model_folder_path.glob("*.pt"))
or list(model_folder_path.glob("*.bin"))
):
grouped_by_base_name.setdefault(
remove_numbers_pattern.sub("", model_file.name), []
).append(model_file)

# Load required parameters
config = ExLlamaConfig((model_folder_path / "config.json").as_posix())

break
if not model_file_found:
# Choose the group with maximum files having the same base name after removing numbers
max_group = max(grouped_by_base_name.values(), key=len, default=[])
if len(max_group) == 1:
# If there is only one file in the group, use the largest file among all groups with a single file
model_path = max(
(
group[0]
for group in grouped_by_base_name.values()
if len(group) == 1
),
key=lambda x: x.stat().st_size,
).as_posix()
elif len(max_group) > 1:
# If there are multiple files in the group, use all of them as the model path
model_path = [model_file.as_posix() for model_file in max_group]
else:
# If there is no file in the group, raise an error
raise FileNotFoundError(
f"No model has been found in {model_folder_path}."
)

# Required parameters
config = ExLlamaConfig((model_folder_path / "config.json").as_posix())
config.model_path = model_file_found[-1].as_posix() # type: ignore
config.model_path = ( # type: Union[str, List[str]] # type: ignore
model_path
)
config.max_seq_len = llm_model.max_total_tokens
config.max_input_len = llm_model.max_total_tokens
config.max_attention_size = 2048**2
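
Note on the checkpoint discovery above: files are grouped by their name with all digits stripped, so shards such as "model-00001-of-00002.safetensors" and "model-00002-of-00002.safetensors" collapse into one group and can be passed to ExLlama as a list of paths, while a single-file checkpoint stays a plain string. A simplified, self-contained sketch of that grouping idea (the filenames and the `files`/`largest_group` names are illustrative; the real code also globs *.pt/*.bin and, when every group has one member, picks the largest file by size):

from re import compile
from pathlib import Path

remove_numbers_pattern = compile(r"\d+")

# Hypothetical folder contents, for illustration only
files = [
    Path("model-00001-of-00002.safetensors"),
    Path("model-00002-of-00002.safetensors"),
]

grouped = {}  # digit-stripped name -> list of matching files
for f in files:
    # "model-00001-of-00002.safetensors" -> "model--of-.safetensors"
    grouped.setdefault(remove_numbers_pattern.sub("", f.name), []).append(f)

largest_group = max(grouped.values(), key=len, default=[])
if len(largest_group) > 1:
    model_path = [p.as_posix() for p in largest_group]  # sharded checkpoint
elif len(largest_group) == 1:
    model_path = largest_group[0].as_posix()            # single-file checkpoint
else:
    raise FileNotFoundError("no checkpoint found")

print(model_path)
# ['model-00001-of-00002.safetensors', 'model-00002-of-00002.safetensors']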
@@ -322,7 +345,7 @@ def _make_config(
logger.info(
f"Rotary embedding base has been set to {config.rotary_embedding_base}"
)

# For ROCm (AMD GPUs)
if version.hip:
config.rmsnorm_no_half2 = True
9 changes: 7 additions & 2 deletions llama_api/utils/logger.py
@@ -1,6 +1,7 @@
"""Logger module for the API"""
# flake8: noqa
from contextlib import contextmanager
from datetime import date
import logging
from dataclasses import dataclass
from pathlib import Path
@@ -14,8 +15,12 @@ class LoggingConfig:
logger_level: int = logging.DEBUG
console_log_level: int = logging.INFO
file_log_level: Optional[int] = logging.DEBUG
file_log_name: Optional[str] = "./logs/debug.log"
logging_format: str = "[%(asctime)s] %(name)s:%(levelname)s - %(message)s"
file_log_name: Optional[
str
] = f"./logs/{date.today().strftime('%Y-%m-%d')}-debug.log"
logging_format: str = (
"[%(asctime)s] %(name)s:%(levelname)s - %(message)s"
)
color: bool = True
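
Note on the logger change: the default log path now embeds the current date, so each day's run writes to its own file instead of a single ./logs/debug.log. A tiny illustration of what the new default evaluates to (the date shown is just the commit date, for example's sake):

from datetime import date

# On 2023-09-09 this evaluates to "./logs/2023-09-09-debug.log"
file_log_name = f"./logs/{date.today().strftime('%Y-%m-%d')}-debug.log"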

