fix: Fix memory allocation of ndarray (#1704)
* Fix memory allocation of ndarray

* Add basic LlamaState tests

* Improve LlamaState test and fix rng / seed

---------

Co-authored-by: Andrei <[email protected]>
xu-song and abetlen authored Sep 19, 2024
1 parent 9b64bb5 commit 22cedad
Showing 2 changed files with 64 additions and 15 deletions.
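Beyond the ndarray fix, this commit moves seed handling off `llama_context_params` and onto the `Llama` object itself (`self._seed`), which the samplers now read on every call; `save_state`/`load_state` round-trip that seed, so a restored state reproduces subsequent sampling. A minimal sketch of the behavior this enables, assuming a local GGUF model at ./model.gguf (hypothetical path, not part of the commit):

from llama_cpp import Llama

llama = Llama(model_path="./model.gguf", seed=1337)

state = llama.save_state()  # now also captures the current seed

first = llama.create_completion("Pick a number from 1 to 10:", max_tokens=4)
second = llama.create_completion("Pick a number from 1 to 10:", max_tokens=4)

llama.load_state(state)  # restores scores, input_ids, n_tokens, and the seed
third = llama.create_completion("Pick a number from 1 to 10:", max_tokens=4)

# With the default temperature (> 0), `first` and `second` will generally
# differ, while `third` should match `first` because the seed was restored.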
34 changes: 19 additions & 15 deletions llama_cpp/llama.py
@@ -7,6 +7,7 @@
 import json
 import ctypes
 import typing
+import random
 import fnmatch
 import warnings
 import contextlib
@@ -301,9 +302,11 @@ def __init__(
         self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
         self.n_threads_batch = n_threads_batch or multiprocessing.cpu_count()
 
+        # Used by the sampler
+        self._seed = seed or llama_cpp.LLAMA_DEFAULT_SEED
+
         # Context Params
         self.context_params = llama_cpp.llama_context_default_params()
-        self.context_params.seed = seed
         self.context_params.n_ctx = n_ctx
         self.context_params.n_batch = self.n_batch
         self.context_params.n_threads = self.n_threads
@@ -613,8 +616,7 @@ def set_seed(self, seed: int):
         Args:
             seed: The random seed.
         """
-        # TODO: Fix this
-        # llama_cpp.llama_set_rng_seed(self._ctx.ctx, seed)
+        self._seed = seed
 
     def reset(self):
         """Reset the model state."""
@@ -672,7 +674,6 @@ def _init_sampler(
         penalize_nl: bool = True,
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
-        seed: Optional[int] = None,
     ):
         sampler = internals.LlamaSampler()
 
@@ -715,22 +716,22 @@ def apply_func(token_data_array: llama_cpp.llama_token_data_array_p):
 
         if temp < 0.0:
             sampler.add_softmax()
-            sampler.add_dist(seed or llama_cpp.LLAMA_DEFAULT_SEED)
+            sampler.add_dist(self._seed)
         elif temp == 0.0:
             sampler.add_greedy()
         else:
             if mirostat_mode == 1:
                 mirostat_m = 100
                 sampler.add_mirostat(
                     self._n_vocab,
-                    seed or llama_cpp.LLAMA_DEFAULT_SEED,
+                    self._seed,
                     mirostat_tau,
                     mirostat_eta,
                     mirostat_m,
                 )
             elif mirostat_mode == 2:
                 sampler.add_mirostat_v2(
-                    seed or llama_cpp.LLAMA_DEFAULT_SEED,
+                    self._seed,
                     mirostat_tau,
                     mirostat_eta,
                 )
@@ -743,7 +744,7 @@ def apply_func(token_data_array: llama_cpp.llama_token_data_array_p):
             sampler.add_top_p(top_p, min_keep)
             sampler.add_min_p(min_p, min_keep)
             sampler.add_temp(temp)
-            sampler.add_dist(seed or llama_cpp.LLAMA_DEFAULT_SEED)
+            sampler.add_dist(self._seed)
         return sampler
 
     def sample(
@@ -826,7 +827,6 @@ def generate(
         logits_processor: Optional[LogitsProcessorList] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         grammar: Optional[LlamaGrammar] = None,
-        seed: Optional[int] = None,
     ) -> Generator[int, Optional[Sequence[int]], None]:
         """Create a generator of tokens from a prompt.
 
@@ -865,7 +865,6 @@
             penalize_nl=penalize_nl,
             logits_processor=logits_processor,
             grammar=grammar,
-            seed=seed,
         )
 
         # Check for kv cache prefix match
@@ -1301,9 +1300,10 @@ def logit_bias_processor(
         if self.verbose:
             print("Llama._create_completion: cache miss", file=sys.stderr)
 
-        # TODO: Fix this
-        # if seed is not None:
-        #     self._ctx.set_rng_seed(seed)
+        if seed is not None:
+            self.set_seed(seed)
+        else:
+            self.set_seed(random.Random(self._seed).randint(0, 2 ** 32))
 
         finish_reason = "length"
         multibyte_fix = 0
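A note on the `else` branch above: when no explicit seed is passed, the next seed is derived deterministically from the stored one, so back-to-back completions differ while remaining reproducible from a saved state. A standalone sketch of that scheme (not library code):

import random

def next_seed(current: int) -> int:
    # Deterministic advance: the same stored seed always yields the same
    # successor, which is what makes a restored state replay identically.
    return random.Random(current).randint(0, 2 ** 32)

assert next_seed(1337) == next_seed(1337)  # same input, same successor
print(next_seed(1337))  # a fixed successor seed, distinct from 1337 in practice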
@@ -1324,7 +1324,6 @@ def logit_bias_processor(
             stopping_criteria=stopping_criteria,
             logits_processor=logits_processor,
             grammar=grammar,
-            seed=seed,
         ):
             if llama_cpp.llama_token_is_eog(self._model.model, token):
                 text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
@@ -2136,14 +2135,17 @@ def save_state(self) -> LlamaState:
             n_tokens=self.n_tokens,
             llama_state=bytes(llama_state_compact),
             llama_state_size=n_bytes,
+            seed=self._seed,
         )
 
     def load_state(self, state: LlamaState) -> None:
         # Only filling in up to `n_tokens` and then zero-ing out the rest
         self.scores[: state.n_tokens, :] = state.scores.copy()
-        self.scores[state.n_tokens :, :] = 0.0
+        rest = self.scores[state.n_tokens :, :]
+        rest[rest > 0] = 0.0
         self.input_ids = state.input_ids.copy()
         self.n_tokens = state.n_tokens
+        self._seed = state.seed
         state_size = state.llama_state_size
         LLamaStateArrayType = ctypes.c_uint8 * state_size
         llama_state = LLamaStateArrayType.from_buffer_copy(state.llama_state)
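The masked write in `load_state` is the allocation fix named in the commit title: `self.scores` spans the full context (`n_ctx` rows by `n_vocab` float32 columns), and a blanket `self.scores[state.n_tokens :, :] = 0.0` writes every element, forcing the OS to commit physical pages for the whole array. Writing only where stale nonzero scores exist leaves untouched zero pages unmaterialized. A rough standalone illustration of that rationale (our reading of the fix; the dimensions are made up):

import numpy as np

# np.zeros is backed by copy-on-write zero pages: physical memory is
# committed only when a page is first written.
n_ctx, n_vocab = 4096, 32000
scores = np.zeros((n_ctx, n_vocab), dtype=np.single)  # ~0.5 GiB virtual

scores[:8, :] = 1.0  # only the first few rows become resident

rest = scores[8:, :]
rest[rest > 0] = 0.0  # writes no elements of `scores`, so its untouched pages
                      # stay uncommitted (the boolean mask is only a transient)
# versus: scores[8:, :] = 0.0  # writes, and thus commits, every remaining page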
@@ -2321,12 +2323,14 @@ def __init__(
         n_tokens: int,
         llama_state: bytes,
         llama_state_size: int,
+        seed: int,
     ):
         self.input_ids = input_ids
         self.scores = scores
         self.n_tokens = n_tokens
         self.llama_state = llama_state
         self.llama_state_size = llama_state_size
+        self.seed = seed
 
 
 LogitsProcessor = Callable[
45 changes: 45 additions & 0 deletions tests/test_llama.py
@@ -171,3 +171,48 @@ def logit_processor_func(input_ids, logits):
         logits_processor=logit_processors
     )
     assert output["choices"][0]["text"].lower().startswith("rot")
+
+    model.set_seed(1337)
+
+    state = model.save_state()
+
+    output = model.create_completion(
+        "Pick a number from 1 to 10?:\n",
+        max_tokens=4,
+        top_k=50,
+        top_p=0.9,
+        temperature=0.8,
+        grammar=llama_cpp.LlamaGrammar.from_string("""
+root ::= "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9" | "10"
+""")
+    )
+    number_1 = output["choices"][0]["text"]
+
+    output = model.create_completion(
+        "Pick a number from 1 to 10?:\n",
+        max_tokens=4,
+        top_k=50,
+        top_p=0.9,
+        temperature=0.8,
+        grammar=llama_cpp.LlamaGrammar.from_string("""
+root ::= "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9" | "10"
+""")
+    )
+    number_2 = output["choices"][0]["text"]
+
+    model.load_state(state)
+
+    output = model.create_completion(
+        "Pick a number from 1 to 10?:\n",
+        max_tokens=4,
+        top_k=50,
+        top_p=0.9,
+        temperature=0.8,
+        grammar=llama_cpp.LlamaGrammar.from_string("""
+root ::= "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9" | "10"
+""")
+    )
+    number_3 = output["choices"][0]["text"]
+
+    assert number_1 != number_2
+    assert number_1 == number_3
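The test pins down the seed contract end to end: starting from `set_seed(1337)`, two consecutive completions draw different numbers because the derived seed advances between calls, while restoring the saved state replays the first draw exactly.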
