
Commit

misc: Format
abetlen committed Sep 19, 2024
1 parent 1e64664 commit 9b64bb5
Showing 6 changed files with 147 additions and 80 deletions.
66 changes: 50 additions & 16 deletions llama_cpp/_internals.py
@@ -100,7 +100,6 @@ def n_params(self) -> int:
def get_tensor(self, name: str) -> ctypes.c_void_p:
return llama_cpp.llama_get_model_tensor(self.model, name.encode("utf-8"))


# Vocab

def token_get_text(self, token: int) -> str:
@@ -460,9 +459,7 @@ def __init__(
self.verbose = verbose
self._exit_stack = ExitStack()

batch = llama_cpp.llama_batch_init(
self._n_tokens, self.embd, self.n_seq_max
)
batch = llama_cpp.llama_batch_init(self._n_tokens, self.embd, self.n_seq_max)

if batch is None:
raise ValueError("Failed to create llama_batch")
@@ -541,6 +538,7 @@ def copy_logits(self, logits: npt.NDArray[np.single]):

# Embedding functions


def normalize_embedding(embedding):
norm = float(np.linalg.norm(embedding))
if norm == 0.0:
@@ -713,11 +711,17 @@ def accept(self, ctx_main: LlamaContext, id: int, apply_grammar: bool):
import ctypes
import llama_cpp


class CustomSampler:
def __init__(self, apply_func: typing.Callable[[llama_cpp.llama_token_data_array], None]):
def __init__(
self, apply_func: typing.Callable[[llama_cpp.llama_token_data_array], None]
):
self.apply_func = apply_func

def apply_wrapper(sampler: llama_cpp.llama_sampler_p, cur_p: llama_cpp.llama_token_data_array_p):
def apply_wrapper(
sampler: llama_cpp.llama_sampler_p,
cur_p: llama_cpp.llama_token_data_array_p,
):
self.apply_func(cur_p)

def free_wrapper(sampler: llama_cpp.llama_sampler_p):
@@ -740,6 +744,7 @@ def free_wrapper(sampler: llama_cpp.llama_sampler_p):
def get_sampler(self) -> llama_cpp.llama_sampler_p:
return ctypes.pointer(self.sampler)


class LlamaSampler:
def __init__(self):
params = llama_cpp.llama_sampler_chain_params()
@@ -788,33 +793,62 @@ def add_temp_ext(self, t: float, delta: float, exponent: float):
self._add_sampler(sampler)

def add_mirostat(self, n_vocab: int, seed: int, tau: float, eta: float, m: int):
sampler = llama_cpp.llama_sampler_init_mirostat(
n_vocab, seed, tau, eta, m
)
sampler = llama_cpp.llama_sampler_init_mirostat(n_vocab, seed, tau, eta, m)
self._add_sampler(sampler)

def add_mirostat_v2(self, seed: int, tau: float, eta: float):
sampler = llama_cpp.llama_sampler_init_mirostat_v2(seed, tau, eta)
self._add_sampler(sampler)

def add_grammar(self, model: LlamaModel, grammar: LlamaGrammar):
sampler = llama_cpp.llama_sampler_init_grammar(model.model, grammar._grammar.encode("utf-8"), grammar._root.encode("utf-8"))
sampler = llama_cpp.llama_sampler_init_grammar(
model.model, grammar._grammar.encode("utf-8"), grammar._root.encode("utf-8")
)
self._add_sampler(sampler)

def add_penalties(self, n_vocab: int, special_eos_id: int, linefeed_id: int, penalty_last_n: int, penalty_repeat: float, penalty_freq: float, penalty_present: float, penalize_nl: bool, ignore_eos: bool):
sampler = llama_cpp.llama_sampler_init_penalties(n_vocab, special_eos_id, linefeed_id, penalty_last_n, penalty_repeat, penalty_freq, penalty_present, penalize_nl, ignore_eos)
def add_penalties(
self,
n_vocab: int,
special_eos_id: int,
linefeed_id: int,
penalty_last_n: int,
penalty_repeat: float,
penalty_freq: float,
penalty_present: float,
penalize_nl: bool,
ignore_eos: bool,
):
sampler = llama_cpp.llama_sampler_init_penalties(
n_vocab,
special_eos_id,
linefeed_id,
penalty_last_n,
penalty_repeat,
penalty_freq,
penalty_present,
penalize_nl,
ignore_eos,
)
self._add_sampler(sampler)

def init_logit_bias(self, n_vocab: int, n_logit_bias, logit_bias: llama_cpp.llama_logit_bias_p):
sampler = llama_cpp.llama_sampler_init_logit_bias(n_vocab, n_logit_bias, logit_bias)
def init_logit_bias(
self, n_vocab: int, n_logit_bias, logit_bias: llama_cpp.llama_logit_bias_p
):
sampler = llama_cpp.llama_sampler_init_logit_bias(
n_vocab, n_logit_bias, logit_bias
)
self._add_sampler(sampler)

def add_custom(self, apply_func: Callable[[llama_cpp.llama_token_data_array], None]):
def add_custom(
self, apply_func: Callable[[llama_cpp.llama_token_data_array], None]
):
custom_sampler = CustomSampler(apply_func)
sampler = custom_sampler.get_sampler()
self._add_sampler(sampler)
# NOTE: Must remove custom samplers before free or llama.cpp will try to free them
self.custom_samplers.append((llama_cpp.llama_sampler_chain_n(self.sampler) - 1, custom_sampler))
self.custom_samplers.append(
(llama_cpp.llama_sampler_chain_n(self.sampler) - 1, custom_sampler)
)

def _add_sampler(self, sampler: llama_cpp.llama_sampler_p):
assert self.sampler is not None
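
For orientation (not part of the commit), a minimal sketch of how the reformatted sampler-chain API above might be used; the no-op callback body and the _internals import path are illustrative assumptions:

    from llama_cpp._internals import LlamaSampler

    sampler = LlamaSampler()

    def apply_func(token_data_array):
        # Receives the candidate-token array for the current sampling step;
        # a real callback would rescore the candidates in place.
        pass

    sampler.add_custom(apply_func)            # wrapped in a CustomSampler internally
    sampler.add_mirostat_v2(seed=0, tau=5.0, eta=0.1)
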
98 changes: 56 additions & 42 deletions llama_cpp/llama.py
@@ -255,28 +255,28 @@ def __init__(
for i, (k, v) in enumerate(kv_overrides.items()):
self._kv_overrides_array[i].key = k.encode("utf-8")
if isinstance(v, bool):
self._kv_overrides_array[i].tag = (
llama_cpp.LLAMA_KV_OVERRIDE_TYPE_BOOL
)
self._kv_overrides_array[
i
].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_BOOL
self._kv_overrides_array[i].value.val_bool = v
elif isinstance(v, int):
self._kv_overrides_array[i].tag = (
llama_cpp.LLAMA_KV_OVERRIDE_TYPE_INT
)
self._kv_overrides_array[
i
].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_INT
self._kv_overrides_array[i].value.val_i64 = v
elif isinstance(v, float):
self._kv_overrides_array[i].tag = (
llama_cpp.LLAMA_KV_OVERRIDE_TYPE_FLOAT
)
self._kv_overrides_array[
i
].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_FLOAT
self._kv_overrides_array[i].value.val_f64 = v
elif isinstance(v, str): # type: ignore
v_bytes = v.encode("utf-8")
if len(v_bytes) > 128: # TODO: Make this a constant
raise ValueError(f"Value for {k} is too long: {v}")
v_bytes = v_bytes.ljust(128, b"\0")
self._kv_overrides_array[i].tag = (
llama_cpp.LLAMA_KV_OVERRIDE_TYPE_STR
)
self._kv_overrides_array[
i
].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_STR
# copy min(v_bytes, 128) to str_value
address = typing.cast(
int,
@@ -292,9 +292,9 @@ def __init__(
else:
raise ValueError(f"Unknown value type for {k}: {v}")

self._kv_overrides_array[-1].key = (
b"\0" # ensure sentinel element is zeroed
)
self._kv_overrides_array[
-1
].key = b"\0" # ensure sentinel element is zeroed
self.model_params.kv_overrides = self._kv_overrides_array

self.n_batch = min(n_ctx, n_batch) # ???
@@ -431,9 +431,9 @@ def free_lora_adapter():

self.chat_format = chat_format
self.chat_handler = chat_handler
self._chat_handlers: Dict[str, llama_chat_format.LlamaChatCompletionHandler] = (
{}
)
self._chat_handlers: Dict[
str, llama_chat_format.LlamaChatCompletionHandler
] = {}

self.draft_model = draft_model

@@ -580,7 +580,10 @@ def tokenize(
return self.tokenizer_.tokenize(text, add_bos, special)

def detokenize(
self, tokens: List[int], prev_tokens: Optional[List[int]] = None, special: bool = False
self,
tokens: List[int],
prev_tokens: Optional[List[int]] = None,
special: bool = False,
) -> bytes:
"""Detokenize a list of tokens.
@@ -592,7 +595,9 @@ def detokenize(
Returns:
The detokenized string.
"""
return self.tokenizer_.detokenize(tokens, prev_tokens=prev_tokens, special=special)
return self.tokenizer_.detokenize(
tokens, prev_tokens=prev_tokens, special=special
)

def set_cache(self, cache: Optional[BaseLlamaCache]):
"""Set the cache.
@@ -681,12 +686,16 @@ def apply_func(token_data_array: llama_cpp.llama_token_data_array_p):
recarray = np.recarray(
shape=(size,),
dtype=np.dtype(
[("id", np.intc), ("logit", np.single), ("p", np.single)], align=True
[("id", np.intc), ("logit", np.single), ("p", np.single)],
align=True,
),
buf=(llama_cpp.llama_token_data * size).from_address(
data_soa_address
),
buf=(llama_cpp.llama_token_data * size).from_address(data_soa_address),
)
for logit_processor in logits_processor:
recarray.logit[:] = logit_processor(self._input_ids, recarray.logit)

sampler.add_custom(apply_func)

sampler.add_penalties(
@@ -698,7 +707,7 @@ def apply_func(token_data_array: llama_cpp.llama_token_data_array_p):
penalty_freq=frequency_penalty,
penalty_present=presence_penalty,
penalize_nl=penalize_nl,
ignore_eos=False
ignore_eos=False,
)

if grammar is not None:
@@ -841,22 +850,22 @@ def generate(
# Reset mirostat sampling
self._mirostat_mu = ctypes.c_float(2.0 * mirostat_tau)
self._sampler = self._init_sampler(
top_k=top_k,
top_p=top_p,
min_p=min_p,
typical_p=typical_p,
temp=temp,
repeat_penalty=repeat_penalty,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
tfs_z=tfs_z,
mirostat_mode=mirostat_mode,
mirostat_tau=mirostat_tau,
mirostat_eta=mirostat_eta,
penalize_nl=penalize_nl,
logits_processor=logits_processor,
grammar=grammar,
seed=seed,
top_k=top_k,
top_p=top_p,
min_p=min_p,
typical_p=typical_p,
temp=temp,
repeat_penalty=repeat_penalty,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
tfs_z=tfs_z,
mirostat_mode=mirostat_mode,
mirostat_tau=mirostat_tau,
mirostat_eta=mirostat_eta,
penalize_nl=penalize_nl,
logits_processor=logits_processor,
grammar=grammar,
seed=seed,
)

# Check for kv cache prefix match
@@ -872,8 +881,11 @@ def generate(
tokens = tokens[longest_prefix:]
self.n_tokens = longest_prefix
if self.verbose:
print(f"Llama.generate: {longest_prefix} prefix-match hit, "
f"remaining {len(tokens)} prompt tokens to eval", file=sys.stderr)
print(
f"Llama.generate: {longest_prefix} prefix-match hit, "
f"remaining {len(tokens)} prompt tokens to eval",
file=sys.stderr,
)

# Reset the model state
if reset:
@@ -1032,7 +1044,9 @@ def decode_batch(seq_sizes: List[int]):
for j in range(size)
]
if normalize:
embedding = [internals.normalize_embedding(e) for e in embedding]
embedding = [
internals.normalize_embedding(e) for e in embedding
]
data.append(embedding)
pos += size
else:
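
As a usage sketch (not part of the commit) for the kv-override packing reformatted above; the model path and override key are placeholder assumptions:

    from llama_cpp import Llama

    # Each override is packed into _kv_overrides_array with the matching tag:
    # bool -> val_bool, int -> val_i64, float -> val_f64, str -> str_value (<= 128 bytes).
    llm = Llama(
        model_path="./models/example.gguf",                     # placeholder path
        kv_overrides={"tokenizer.ggml.add_bos_token": False},   # example key, assumed
    )
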
6 changes: 3 additions & 3 deletions llama_cpp/llama_cache.py
@@ -52,9 +52,9 @@ class LlamaRAMCache(BaseLlamaCache):
def __init__(self, capacity_bytes: int = (2 << 30)):
super().__init__(capacity_bytes)
self.capacity_bytes = capacity_bytes
self.cache_state: OrderedDict[Tuple[int, ...], "llama_cpp.llama.LlamaState"] = (
OrderedDict()
)
self.cache_state: OrderedDict[
Tuple[int, ...], "llama_cpp.llama.LlamaState"
] = OrderedDict()

@property
def cache_size(self):
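
A small sketch (not part of the commit) of attaching the cache whose annotation is reformatted above, using the set_cache method shown in the llama.py diff; the model path is a placeholder:

    from llama_cpp import Llama
    from llama_cpp.llama_cache import LlamaRAMCache

    llm = Llama(model_path="./models/example.gguf")        # placeholder path
    llm.set_cache(LlamaRAMCache(capacity_bytes=2 << 30))   # keep up to ~2 GiB of reusable state
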
3 changes: 2 additions & 1 deletion llama_cpp/llama_grammar.py
@@ -15,6 +15,7 @@

LLAMA_GRAMMAR_DEFAULT_ROOT = "root"


class LlamaGrammar:
def __init__(self, *args, _grammar: str, **kwargs):
self._grammar = _grammar
Expand All @@ -23,7 +24,7 @@ def __init__(self, *args, _grammar: str, **kwargs):
@classmethod
def from_string(cls, grammar: str, verbose: bool = True) -> "LlamaGrammar":
return cls(_grammar=grammar)

@classmethod
def from_file(cls, file: Union[str, Path], verbose: bool = True) -> "LlamaGrammar":
try:
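
For reference (not part of the commit), the from_string constructor touched above can be used like this; the one-rule GBNF grammar is an illustrative assumption:

    from llama_cpp.llama_grammar import LlamaGrammar

    # "root" matches LLAMA_GRAMMAR_DEFAULT_ROOT defined earlier in this file.
    grammar = LlamaGrammar.from_string('root ::= "yes" | "no"')
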
33 changes: 22 additions & 11 deletions llama_cpp/llama_tokenizer.py
@@ -27,7 +27,10 @@ def tokenize(

@abc.abstractmethod
def detokenize(
self, tokens: List[int], prev_tokens: Optional[List[int]] = None, special: bool = False
self,
tokens: List[int],
prev_tokens: Optional[List[int]] = None,
special: bool = False,
) -> bytes:
"""Detokenize the tokens into text.
Expand All @@ -49,7 +52,10 @@ def tokenize(
return self._model.tokenize(text, add_bos=add_bos, special=special)

def detokenize(
self, tokens: List[int], prev_tokens: Optional[List[int]] = None, special: bool = False
self,
tokens: List[int],
prev_tokens: Optional[List[int]] = None,
special: bool = False,
) -> bytes:
return self._model.detokenize(tokens, special=special)

@@ -80,19 +86,24 @@ def tokenize(
)

def detokenize(
self, tokens: List[int], prev_tokens: Optional[List[int]] = None, special: bool = False
self,
tokens: List[int],
prev_tokens: Optional[List[int]] = None,
special: bool = False,
) -> bytes:
skip_special_tokens = not special
if prev_tokens is not None:
text = self.hf_tokenizer.decode(prev_tokens + tokens, skip_special_tokens=skip_special_tokens).encode(
"utf-8", errors="ignore"
)
prev_text = self.hf_tokenizer.decode(prev_tokens, skip_special_tokens=skip_special_tokens).encode(
"utf-8", errors="ignore"
)
text = self.hf_tokenizer.decode(
prev_tokens + tokens, skip_special_tokens=skip_special_tokens
).encode("utf-8", errors="ignore")
prev_text = self.hf_tokenizer.decode(
prev_tokens, skip_special_tokens=skip_special_tokens
).encode("utf-8", errors="ignore")
return text[len(prev_text) :]
else:
return self.hf_tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens).encode("utf-8", errors="ignore")
return self.hf_tokenizer.decode(
tokens, skip_special_tokens=skip_special_tokens
).encode("utf-8", errors="ignore")

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: str) -> "LlamaHFTokenizer":
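
The prev_tokens branch above implements incremental detokenization: decode the full sequence, decode the already-emitted prefix, and return only the byte-level difference, so multi-token characters are never split mid-stream. A toy illustration of that slicing trick (the vocabulary and decode helper are stand-ins, not library code):

    vocab = {0: "he", 1: "llo", 2: " wo", 3: "rld"}

    def decode(tokens):
        # Stand-in for hf_tokenizer.decode(...).encode("utf-8", errors="ignore").
        return "".join(vocab[t] for t in tokens).encode("utf-8")

    prev_tokens, new_tokens = [0, 1], [2, 3]
    delta = decode(prev_tokens + new_tokens)[len(decode(prev_tokens)):]
    assert delta == b" world"
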
