add eos_id_list to llama.cpp
youth123 committed Jun 24, 2024
1 parent 4b65b64 commit 3a4d579
Showing 13 changed files with 122 additions and 55 deletions.
26 changes: 21 additions & 5 deletions common/common.cpp
@@ -2417,14 +2417,21 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
         }
     }

+    const int n_eos = llama_n_eos(llama_get_model(lctx));
+    std::vector<int32_t> eos_tokens(n_eos, 0);
+    int32_t* eos_ptr = eos_tokens.data();
+    llama_token_eos(llama_get_model(lctx), eos_ptr);
     if (params.ignore_eos) {
-        params.sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
+        for (int32_t i = 0; i < n_eos; ++i) {
+            params.sparams.logit_bias[eos_ptr[i]] = -INFINITY;
+        }
     }

     if (params.warmup) {
         LOG("warming up the model with an empty run\n");

-        std::vector<llama_token> tmp = { llama_token_bos(model), llama_token_eos(model), };
+        std::vector<llama_token> tmp = { llama_token_bos(model) };
+        tmp.insert(tmp.end(), eos_tokens.begin(), eos_tokens.end());
         llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
         llama_kv_cache_clear(lctx);
         llama_synchronize(lctx);
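The hunk above shows the commit's new calling convention: llama_n_eos reports how many EOS ids the model declares, and llama_token_eos now fills a caller-provided buffer instead of returning a single id. A minimal helper wrapping that pattern, as a sketch under those assumed signatures (the helper itself is not part of the commit):

```cpp
#include <vector>

// Sketch: collect the model's EOS id list with the two-call pattern this
// commit introduces. Assumes llama_n_eos(const llama_model *) -> int and
// llama_token_eos(const llama_model *, int32_t *) as shown in the diff.
static std::vector<int32_t> get_eos_tokens(const struct llama_model * model) {
    const int n_eos = llama_n_eos(model);       // length of the EOS id list
    std::vector<int32_t> eos_tokens(n_eos, 0);  // caller-owned buffer
    llama_token_eos(model, eos_tokens.data());  // fills the buffer with n_eos ids
    return eos_tokens;
}
```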
@@ -3357,8 +3364,17 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "hellaswag: %s # default: false\n", params.hellaswag ? "true" : "false");
     fprintf(stream, "hellaswag_tasks: %zu # default: 400\n", params.hellaswag_tasks);

-    const auto logit_bias_eos = sparams.logit_bias.find(llama_token_eos(llama_get_model(lctx)));
-    const bool ignore_eos = logit_bias_eos != sparams.logit_bias.end() && logit_bias_eos->second == -INFINITY;
+    const int n_eos = llama_n_eos(llama_get_model(lctx));
+    std::vector<int32_t> eos_tokens(n_eos, 0);
+    int32_t* eos_ptr = eos_tokens.data();
+    llama_token_eos(llama_get_model(lctx), eos_ptr);
+    bool ignore_eos = false;
+    for (auto eos : eos_tokens) {
+        const auto logit_bias_eos = sparams.logit_bias.find(eos);
+        if (logit_bias_eos != sparams.logit_bias.end() && logit_bias_eos->second == -INFINITY) {
+            ignore_eos = true;
+        }
+    }
     fprintf(stream, "ignore_eos: %s # default: false\n", ignore_eos ? "true" : "false");

     yaml_dump_string_multiline(stream, "in_prefix", params.input_prefix.c_str());
@@ -3371,7 +3387,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l

     fprintf(stream, "logit_bias:\n");
     for (std::pair<llama_token, float> lb : sparams.logit_bias) {
-        if (ignore_eos && lb.first == logit_bias_eos->first) {
+        if (ignore_eos && std::count(eos_tokens.begin(), eos_tokens.end(), lb.first)) {
             continue;
         }
         fprintf(stream, " %d: %f", lb.first, lb.second);
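Note that the dump reports ignore_eos as soon as any single EOS id carries a -INFINITY bias. If all-of semantics were preferred, a variant might look like the following (an alternative sketch, not what the commit implements; assumes <algorithm> and the same sparams object):

```cpp
// Alternative semantics (not the commit's): report ignore_eos only when
// every id in the list is biased to -INFINITY. Requires <algorithm>.
const bool all_ignored = std::all_of(eos_tokens.begin(), eos_tokens.end(),
    [&](int32_t eos) {
        const auto it = sparams.logit_bias.find(eos);
        return it != sparams.logit_bias.end() && it->second == -INFINITY;
    });
```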
6 changes: 5 additions & 1 deletion common/train.cpp
@@ -240,7 +240,11 @@ int64_t get_example_targets_batch(

     ggml_set_f32(target_probs, 0.0f);
     llama_token bos = llama_token_bos(llama_get_model(lctx));
-    llama_token eos = llama_token_eos(llama_get_model(lctx));
+    const int n_eos = llama_n_eos(llama_get_model(lctx));
+    std::vector<int32_t> eos_tokens(n_eos, 0);
+    int32_t* eos_ptr = eos_tokens.data();
+    llama_token_eos(llama_get_model(lctx), eos_ptr);
+    llama_token eos = eos_ptr[0];
     // printf("%s: example_id=%d n_batch=%d n_train_samples=%zu\n", __func__, example_id, n_batch, n_train_samples);
     for (int k=0; k<n_batch; ++k) {
         // printf("%s: batch %d\n", __func__, k);
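The training path only needs one terminator, so it keeps the first entry of the list; if a model declared no EOS ids at all, eos_ptr[0] would read past the buffer. A defensive variant, as a sketch that is not in the commit:

```cpp
// Hypothetical guard (not in the commit): fall back to -1, llama.cpp's usual
// "no token" sentinel, when the model declares an empty EOS list.
llama_token eos = eos_tokens.empty() ? -1 : eos_tokens[0];
```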
9 changes: 5 additions & 4 deletions convert-hf-to-gguf.py
@@ -801,7 +801,7 @@ def set_vocab(self):
         self._set_vocab_sentencepiece()
         self.gguf_writer.add_add_bos_token(False)
         self.gguf_writer.add_pad_token_id(3)
-        self.gguf_writer.add_eos_token_id(1)
+        self.gguf_writer.add_eos_token_id_list([1])
         self.gguf_writer.add_unk_token_id(0)

     def set_gguf_parameters(self):
@@ -2339,8 +2339,8 @@ def set_vocab(self):
         field = neox_reader.get_field(gguf.Keys.Tokenizer.BOS_ID)
         self.gguf_writer.add_bos_token_id(field.parts[-1].tolist()[0] if field else 1)

-        field = neox_reader.get_field(gguf.Keys.Tokenizer.EOS_ID)
-        self.gguf_writer.add_eos_token_id(field.parts[-1].tolist()[0] if field else 0)
+        field = neox_reader.get_field(gguf.Keys.Tokenizer.EOS_ID_LIST)
+        self.gguf_writer.add_eos_token_id_list([field.parts[-1].tolist()[0] if field else 0])

         field = neox_reader.get_field(gguf.Keys.Tokenizer.UNK_ID)
         self.gguf_writer.add_unk_token_id(field.parts[-1].tolist()[0] if field else 0)
@@ -2875,9 +2875,10 @@ def set_vocab(self):
         self.gguf_writer.add_tokenizer_pre(tokpre)
         self.gguf_writer.add_token_list(tokens)
         self.gguf_writer.add_token_types(toktypes)
+        self.gguf_writer.add_eos_token_id_list([151329, 151336, 151338])

         special_vocab = gguf.SpecialVocab(dir_model, load_merges=False)
-        special_vocab.chat_template = "ChatGLM4"
+        special_vocab.chat_template = "chatglm4"
         special_vocab.merges = merges
         # only add special tokens when they were not already loaded from config.json
         # if len(special_vocab.special_token_ids) == 0:
2 changes: 1 addition & 1 deletion convert-llama-ggml-to-gguf.py
@@ -331,7 +331,7 @@ def add_vocab(self, gguf_writer):
         gguf_writer.add_token_types(toktypes)
         gguf_writer.add_unk_token_id(0)
         gguf_writer.add_bos_token_id(1)
-        gguf_writer.add_eos_token_id(2)
+        gguf_writer.add_eos_token_id_list([2])

     def add_tensors(self, gguf_writer):
         tensor_map = self.name_map
3 changes: 1 addition & 2 deletions examples/gritlm/gritlm.cpp
@@ -95,7 +95,6 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo
    std::string result;

    const llama_model * mdl = llama_get_model(ctx);
-   llama_token eos_token = llama_token_eos(mdl);

    llama_kv_cache_clear(ctx);
    llama_set_causal_attn(ctx, true);
@@ -123,7 +122,7 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo
        auto candidates_p = llama_token_data_array{ candidates.data(), candidates.size(), false };

        llama_token token = llama_sample_token_greedy(ctx, &candidates_p);
-       if (token == eos_token) {
+       if (llama_token_is_eog(mdl, token)) {
            break;
        }

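Rather than tracking the EOS list directly, gritlm switches to llama_token_is_eog, which asks the model whether a token ends generation at all. A loop skeleton around that check, as a sketch that reuses the existing common helper llama_token_to_piece (batch setup and decoding omitted):

```cpp
// Generation-loop skeleton: llama_token_is_eog covers every terminator the
// model declares (the EOS list, EOT, ...), not one hard-coded EOS id.
for (;;) {
    const llama_token token = llama_sample_token_greedy(ctx, &candidates_p);
    if (llama_token_is_eog(mdl, token)) {
        break;  // any end-of-generation token stops the loop
    }
    result += llama_token_to_piece(ctx, token);  // append the decoded piece
    // ... feed `token` back in and decode the next batch ...
}
```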
9 changes: 7 additions & 2 deletions examples/retrieval/retrieval.cpp
@@ -184,8 +184,13 @@ int main(int argc, char ** argv) {
            return 1;
        }
        // add eos if not present
-       if (llama_token_eos(model) >= 0 && (inp.empty() || inp.back() != llama_token_eos(model))) {
-           inp.push_back(llama_token_eos(model));
+       const int n_eos = llama_n_eos(model);
+       std::vector<int32_t> eos_tokens(n_eos, 0);
+       int32_t* eos_ptr = eos_tokens.data();
+       llama_token_eos(model, eos_ptr);
+
+       if (!eos_tokens.empty() && (inp.empty() || !std::count(eos_tokens.begin(), eos_tokens.end(), inp.back()))) {
+           inp.insert(inp.end(), eos_tokens.begin(), eos_tokens.end());
        }
        chunk.tokens = inp;
    }
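A small design note on the membership test: std::count scans the entire list even after a hit. An equivalent check with std::find, which short-circuits at the first match (a sketch with the same behavior):

```cpp
// Equivalent membership test (design note, not in the commit): std::find
// stops at the first matching EOS id instead of scanning the whole list.
const bool ends_with_eos = !inp.empty() &&
    std::find(eos_tokens.begin(), eos_tokens.end(), inp.back()) != eos_tokens.end();
```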
22 changes: 18 additions & 4 deletions examples/server/server.cpp
@@ -1021,7 +1021,13 @@ struct server_context {
        slot.sparams.logit_bias.clear();

        if (json_value(data, "ignore_eos", false)) {
-           slot.sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
+           const int n_eos = llama_n_eos(model);
+           std::vector<int32_t> eos_tokens(n_eos, 0);
+           int32_t* eos_ptr = eos_tokens.data();
+           llama_token_eos(model, eos_ptr);
+           for (int32_t i = 0; i < n_eos; ++i) {
+               slot.sparams.logit_bias[eos_ptr[i]] = -INFINITY;
+           }
        }

        const auto & logit_bias = data.find("logit_bias");
@@ -1308,9 +1314,17 @@ struct server_context {
    }

    json get_formated_generation(const server_slot & slot) const {
-       const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(model));
-       const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() && eos_bias->second < 0.0f && std::isinf(eos_bias->second);
-
+       const int n_eos = llama_n_eos(model);
+       std::vector<int32_t> eos_tokens(n_eos, 0);
+       int32_t* eos_ptr = eos_tokens.data();
+       llama_token_eos(model, eos_ptr);
+       bool ignore_eos = false;
+       for (auto eos : eos_tokens) {
+           const auto logit_bias_eos = slot.sparams.logit_bias.find(eos);
+           if (logit_bias_eos != slot.sparams.logit_bias.end() && logit_bias_eos->second < 0.0f && std::isinf(logit_bias_eos->second)) {
+               ignore_eos = true;
+           }
+       }
        std::vector<std::string> samplers_sequence;
        samplers_sequence.reserve(slot.sparams.samplers_sequence.size());
        for (const auto & sampler_type : slot.sparams.samplers_sequence) {
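This mirrors the ignore_eos handling added to common/common.cpp above. A hypothetical helper could centralize the repeated fill-then-bias pattern (a sketch only; the helper name is invented, and it assumes the sampling params keep biases in an unordered_map keyed by token id, as the server slot does):

```cpp
#include <unordered_map>
#include <vector>
#include <cmath>

// Hypothetical consolidation (not in the commit): forbid sampling any of the
// model's EOS ids by biasing them all to -INFINITY.
static void bias_out_eos(const struct llama_model * model,
                         std::unordered_map<llama_token, float> & logit_bias) {
    const int n_eos = llama_n_eos(model);
    std::vector<int32_t> eos_tokens(n_eos, 0);
    llama_token_eos(model, eos_tokens.data());
    for (const int32_t eos : eos_tokens) {
        logit_bias[eos] = -INFINITY;  // never sample this terminator
    }
}
```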
11 changes: 10 additions & 1 deletion examples/speculative/speculative.cpp
@@ -88,12 +88,21 @@ int main(int argc, char ** argv) {
        fprintf(stderr, "vocab_type_dft = %d while vocab_type_tgt = %d\n", vocab_type_dft, vocab_type_tgt);
        return 1;
    }
+   const int n_eos_tgt = llama_n_eos(model_tgt);
+   std::vector<int32_t> eos_tokens_tgt(n_eos_tgt, 0);
+   int32_t* eos_ptr_tgt = eos_tokens_tgt.data();
+   llama_token_eos(model_tgt, eos_ptr_tgt);
+
+   const int n_eos_dft = llama_n_eos(model_dft);
+   std::vector<int32_t> eos_tokens_dft(n_eos_dft, 0);
+   int32_t* eos_ptr_dft = eos_tokens_dft.data();
+   llama_token_eos(model_dft, eos_ptr_dft);

    if (
        llama_add_bos_token(model_tgt) != llama_add_bos_token(model_dft) ||
        llama_add_eos_token(model_tgt) != llama_add_eos_token(model_dft) ||
        llama_token_bos(model_tgt) != llama_token_bos(model_dft) ||
-       llama_token_eos(model_tgt) != llama_token_eos(model_dft)
+       eos_tokens_tgt != eos_tokens_dft
    ) {
        fprintf(stderr, "%s: error: draft model special tokens must match target model to use speculation\n", __func__);
        return 1;
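One subtlety: std::vector's operator!= compares length and element order, so two models declaring the same EOS ids in a different order would still be rejected here. An order-insensitive comparison, as an alternative sketch rather than the commit's behavior, could sort both lists first:

```cpp
#include <algorithm>

// Alternative (not the commit's): compare the EOS id sets regardless of the
// order in which each model lists them.
std::sort(eos_tokens_tgt.begin(), eos_tokens_tgt.end());
std::sort(eos_tokens_dft.begin(), eos_tokens_dft.end());
const bool eos_match = eos_tokens_tgt == eos_tokens_dft;
```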
6 changes: 4 additions & 2 deletions gguf-py/gguf/constants.py
@@ -88,7 +88,7 @@ class Tokenizer:
        SCORES = "tokenizer.ggml.scores"
        MERGES = "tokenizer.ggml.merges"
        BOS_ID = "tokenizer.ggml.bos_token_id"
-       EOS_ID = "tokenizer.ggml.eos_token_id"
+       EOS_ID = "tokenizer.ggml.eos_token_id"  # recommend eos_id_list instead
        UNK_ID = "tokenizer.ggml.unknown_token_id"
        SEP_ID = "tokenizer.ggml.seperator_token_id"
        PAD_ID = "tokenizer.ggml.padding_token_id"
@@ -107,6 +107,8 @@ class Tokenizer:
        SUFFIX_ID = "tokenizer.ggml.suffix_token_id"
        MIDDLE_ID = "tokenizer.ggml.middle_token_id"
        EOT_ID = "tokenizer.ggml.eot_token_id"
+       EOS_ID_LIST = "tokenizer.ggml.eos_token_id_list"
+


#
@@ -1091,7 +1093,7 @@ def get_type(val: Any) -> GGUFValueType:
 KEY_TOKENIZER_SCORES = Keys.Tokenizer.SCORES
 KEY_TOKENIZER_MERGES = Keys.Tokenizer.MERGES
 KEY_TOKENIZER_BOS_ID = Keys.Tokenizer.BOS_ID
-KEY_TOKENIZER_EOS_ID = Keys.Tokenizer.EOS_ID
+KEY_TOKENIZER_EOS_ID_LIST = Keys.Tokenizer.EOS_ID_LIST
 KEY_TOKENIZER_UNK_ID = Keys.Tokenizer.UNK_ID
 KEY_TOKENIZER_SEP_ID = Keys.Tokenizer.SEP_ID
 KEY_TOKENIZER_PAD_ID = Keys.Tokenizer.PAD_ID
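For the consumer side, a sketch of reading the new array key back with ggml's gguf C API (illustrative only; it assumes gguf_find_key, gguf_get_arr_n, and gguf_get_arr_data behave as in ggml of this era, and it is not the commit's loader code):

```cpp
#include <vector>
#include "ggml.h"  // gguf_* API

// Hypothetical reader: copy the tokenizer.ggml.eos_token_id_list array out of
// an open GGUF context; returns an empty vector when the key is absent.
static std::vector<int32_t> read_eos_id_list(struct gguf_context * gctx) {
    std::vector<int32_t> ids;
    const int kid = gguf_find_key(gctx, "tokenizer.ggml.eos_token_id_list");
    if (kid >= 0) {
        const int n = gguf_get_arr_n(gctx, kid);
        const int32_t * data = (const int32_t *) gguf_get_arr_data(gctx, kid);
        ids.assign(data, data + n);  // copy out the declared EOS ids
    }
    return ids;
}
```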
6 changes: 3 additions & 3 deletions gguf-py/gguf/gguf_writer.py
@@ -510,9 +510,9 @@ def add_token_scores(self, scores: Sequence[float]) -> None:

     def add_bos_token_id(self, id: int) -> None:
         self.add_uint32(Keys.Tokenizer.BOS_ID, id)

-    def add_eos_token_id(self, id: int) -> None:
-        self.add_uint32(Keys.Tokenizer.EOS_ID, id)
+    def add_eos_token_id_list(self, id: Sequence[int]) -> None:
+        self.add_array(Keys.Tokenizer.EOS_ID_LIST, id)

     def add_unk_token_id(self, id: int) -> None:
         self.add_uint32(Keys.Tokenizer.UNK_ID, id)