
Commit 3a4d579

add eos_id_list to llama.cpp
1 parent 4b65b64

File tree

13 files changed: +122 -55 lines

common/common.cpp (+21 -5)

@@ -2417,14 +2417,21 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
         }
     }
 
+    const int n_eos = llama_n_eos(llama_get_model(lctx));
+    std::vector<int32_t> eos_tokens(n_eos, 0);
+    int32_t * eos_ptr = eos_tokens.data();
+    llama_token_eos(llama_get_model(lctx), eos_ptr);
     if (params.ignore_eos) {
-        params.sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
+        for (int32_t i = 0; i < n_eos; ++i) {
+            params.sparams.logit_bias[eos_ptr[i]] = -INFINITY;
+        }
     }
 
     if (params.warmup) {
         LOG("warming up the model with an empty run\n");
 
-        std::vector<llama_token> tmp = { llama_token_bos(model), llama_token_eos(model), };
+        std::vector<llama_token> tmp = { llama_token_bos(model) };
+        tmp.insert(tmp.end(), eos_tokens.begin(), eos_tokens.end());
         llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
         llama_kv_cache_clear(lctx);
         llama_synchronize(lctx);
@@ -3357,8 +3364,17 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "hellaswag: %s # default: false\n", params.hellaswag ? "true" : "false");
     fprintf(stream, "hellaswag_tasks: %zu # default: 400\n", params.hellaswag_tasks);
 
-    const auto logit_bias_eos = sparams.logit_bias.find(llama_token_eos(llama_get_model(lctx)));
-    const bool ignore_eos = logit_bias_eos != sparams.logit_bias.end() && logit_bias_eos->second == -INFINITY;
+    const int n_eos = llama_n_eos(llama_get_model(lctx));
+    std::vector<int32_t> eos_tokens(n_eos, 0);
+    int32_t * eos_ptr = eos_tokens.data();
+    llama_token_eos(llama_get_model(lctx), eos_ptr);
+    bool ignore_eos = false;
+    for (auto eos : eos_tokens) {
+        const auto logit_bias_eos = sparams.logit_bias.find(eos);
+        if (logit_bias_eos != sparams.logit_bias.end() && logit_bias_eos->second == -INFINITY) {
+            ignore_eos = true;
+        }
+    }
     fprintf(stream, "ignore_eos: %s # default: false\n", ignore_eos ? "true" : "false");
 
     yaml_dump_string_multiline(stream, "in_prefix", params.input_prefix.c_str());
@@ -3371,7 +3387,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
 
     fprintf(stream, "logit_bias:\n");
     for (std::pair<llama_token, float> lb : sparams.logit_bias) {
-        if (ignore_eos && lb.first == logit_bias_eos->first) {
+        if (ignore_eos && std::count(eos_tokens.begin(), eos_tokens.end(), lb.first)) {
             continue;
         }
         fprintf(stream, "  %d: %f", lb.first, lb.second);
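Every call site touched by this commit repeats the same preamble: ask llama_n_eos for the list length, size a buffer, then let the new two-argument llama_token_eos fill it. A minimal sketch of a helper capturing that pattern (hypothetical, not part of this commit):

    // Hypothetical helper: fetch the model's full EOS id list via the new
    // two-call API (llama_n_eos + buffer-filling llama_token_eos).
    static std::vector<int32_t> get_eos_tokens(const struct llama_model * model) {
        const int n_eos = llama_n_eos(model);       // number of declared EOS ids
        std::vector<int32_t> eos_tokens(n_eos, 0);
        llama_token_eos(model, eos_tokens.data());  // fills the caller's buffer
        return eos_tokens;
    }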

common/train.cpp (+5 -1)

@@ -240,7 +240,11 @@ int64_t get_example_targets_batch(
 
     ggml_set_f32(target_probs, 0.0f);
     llama_token bos = llama_token_bos(llama_get_model(lctx));
-    llama_token eos = llama_token_eos(llama_get_model(lctx));
+    const int n_eos = llama_n_eos(llama_get_model(lctx));
+    std::vector<int32_t> eos_tokens(n_eos, 0);
+    int32_t * eos_ptr = eos_tokens.data();
+    llama_token_eos(llama_get_model(lctx), eos_ptr);
+    llama_token eos = eos_ptr[0];
     // printf("%s: example_id=%d n_batch=%d n_train_samples=%zu\n", __func__, example_id, n_batch, n_train_samples);
     for (int k=0; k<n_batch; ++k) {
         // printf("%s: batch %d\n", __func__, k);

convert-hf-to-gguf.py (+5 -4)

@@ -801,7 +801,7 @@ def set_vocab(self):
         self._set_vocab_sentencepiece()
         self.gguf_writer.add_add_bos_token(False)
         self.gguf_writer.add_pad_token_id(3)
-        self.gguf_writer.add_eos_token_id(1)
+        self.gguf_writer.add_eos_token_id_list([1])
         self.gguf_writer.add_unk_token_id(0)
 
     def set_gguf_parameters(self):
@@ -2339,8 +2339,8 @@ def set_vocab(self):
         field = neox_reader.get_field(gguf.Keys.Tokenizer.BOS_ID)
         self.gguf_writer.add_bos_token_id(field.parts[-1].tolist()[0] if field else 1)
 
-        field = neox_reader.get_field(gguf.Keys.Tokenizer.EOS_ID)
-        self.gguf_writer.add_eos_token_id(field.parts[-1].tolist()[0] if field else 0)
+        field = neox_reader.get_field(gguf.Keys.Tokenizer.EOS_ID_LIST)
+        self.gguf_writer.add_eos_token_id_list([field.parts[-1].tolist()[0] if field else 0])
 
         field = neox_reader.get_field(gguf.Keys.Tokenizer.UNK_ID)
         self.gguf_writer.add_unk_token_id(field.parts[-1].tolist()[0] if field else 0)
@@ -2875,9 +2875,10 @@ def set_vocab(self):
         self.gguf_writer.add_tokenizer_pre(tokpre)
         self.gguf_writer.add_token_list(tokens)
         self.gguf_writer.add_token_types(toktypes)
+        self.gguf_writer.add_eos_token_id_list([151329, 151336, 151338])
 
         special_vocab = gguf.SpecialVocab(dir_model, load_merges=False)
-        special_vocab.chat_template = "ChatGLM4"
+        special_vocab.chat_template = "chatglm4"
         special_vocab.merges = merges
         # only add special tokens when they were not already loaded from config.json
         # if len(special_vocab.special_token_ids) == 0:

convert-llama-ggml-to-gguf.py (+1 -1)

@@ -331,7 +331,7 @@ def add_vocab(self, gguf_writer):
         gguf_writer.add_token_types(toktypes)
         gguf_writer.add_unk_token_id(0)
         gguf_writer.add_bos_token_id(1)
-        gguf_writer.add_eos_token_id(2)
+        gguf_writer.add_eos_token_id_list([2])
 
     def add_tensors(self, gguf_writer):
         tensor_map = self.name_map

examples/gritlm/gritlm.cpp (+1 -2)

@@ -95,7 +95,6 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo
     std::string result;
 
     const llama_model * mdl = llama_get_model(ctx);
-    llama_token eos_token = llama_token_eos(mdl);
 
     llama_kv_cache_clear(ctx);
     llama_set_causal_attn(ctx, true);
@@ -123,7 +122,7 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo
     auto candidates_p = llama_token_data_array{ candidates.data(), candidates.size(), false };
 
     llama_token token = llama_sample_token_greedy(ctx, &candidates_p);
-    if (token == eos_token) {
+    if (llama_token_is_eog(mdl, token)) {
        break;
    }
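With more than one possible terminator, llama_token_is_eog is the natural stop test: it asks the model whether a token ends generation instead of comparing against a single cached id. A sketch of the resulting loop shape, where sample_next is a hypothetical stand-in for the sampling code above:

    // Sketch only: sample_next() stands in for the candidate-building and
    // llama_sample_token_greedy() calls shown in the diff above.
    while (true) {
        const llama_token token = sample_next(ctx);
        if (llama_token_is_eog(mdl, token)) {
            break;  // any of the model's EOS/EOT ids ends generation
        }
        // ... append the token to the result and decode the next step ...
    }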

examples/retrieval/retrieval.cpp (+7 -2)

@@ -184,8 +184,13 @@ int main(int argc, char ** argv) {
             return 1;
         }
         // add eos if not present
-        if (llama_token_eos(model) >= 0 && (inp.empty() || inp.back() != llama_token_eos(model))) {
-            inp.push_back(llama_token_eos(model));
+        const int n_eos = llama_n_eos(model);
+        std::vector<int32_t> eos_tokens(n_eos, 0);
+        int32_t * eos_ptr = eos_tokens.data();
+        llama_token_eos(model, eos_ptr);
+
+        if (!eos_tokens.empty() && (inp.empty() || !std::count(eos_tokens.begin(), eos_tokens.end(), inp.back()))) {
+            inp.insert(inp.end(), eos_tokens.begin(), eos_tokens.end());
        }
        chunk.tokens = inp;
    }
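The membership test on the last token is easy to misread; a hypothetical helper (not in the commit) states the intent directly:

    #include <algorithm>
    #include <vector>

    // Hypothetical helper: true when `inp` already ends in one of the
    // model's EOS ids, so no terminator needs to be appended.
    static bool ends_with_eos(const std::vector<llama_token> & inp,
                              const std::vector<int32_t> & eos_tokens) {
        return !inp.empty() &&
               std::find(eos_tokens.begin(), eos_tokens.end(), inp.back()) != eos_tokens.end();
    }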

examples/server/server.cpp (+18 -4)

@@ -1021,7 +1021,13 @@ struct server_context {
         slot.sparams.logit_bias.clear();
 
         if (json_value(data, "ignore_eos", false)) {
-            slot.sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
+            const int n_eos = llama_n_eos(model);
+            std::vector<int32_t> eos_tokens(n_eos, 0);
+            int32_t * eos_ptr = eos_tokens.data();
+            llama_token_eos(model, eos_ptr);
+            for (int32_t i = 0; i < n_eos; ++i) {
+                slot.sparams.logit_bias[eos_ptr[i]] = -INFINITY;
+            }
         }
 
         const auto & logit_bias = data.find("logit_bias");
@@ -1308,9 +1314,17 @@ struct server_context {
     }
 
     json get_formated_generation(const server_slot & slot) const {
-        const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(model));
-        const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() && eos_bias->second < 0.0f && std::isinf(eos_bias->second);
-
+        const int n_eos = llama_n_eos(model);
+        std::vector<int32_t> eos_tokens(n_eos, 0);
+        int32_t * eos_ptr = eos_tokens.data();
+        llama_token_eos(model, eos_ptr);
+        bool ignore_eos = false;
+        for (auto eos : eos_tokens) {
+            const auto logit_bias_eos = slot.sparams.logit_bias.find(eos);
+            if (logit_bias_eos != slot.sparams.logit_bias.end() && logit_bias_eos->second < 0.0f && std::isinf(logit_bias_eos->second)) {
+                ignore_eos = true;
+            }
+        }
         std::vector<std::string> samplers_sequence;
         samplers_sequence.reserve(slot.sparams.samplers_sequence.size());
         for (const auto & sampler_type : slot.sparams.samplers_sequence) {

examples/speculative/speculative.cpp (+10 -1)

@@ -88,12 +88,21 @@ int main(int argc, char ** argv) {
         fprintf(stderr, "vocab_type_dft = %d while vocab_type_tgt = %d\n", vocab_type_dft, vocab_type_tgt);
         return 1;
     }
+    const int n_eos_tgt = llama_n_eos(model_tgt);
+    std::vector<int32_t> eos_tokens_tgt(n_eos_tgt, 0);
+    int32_t * eos_ptr_tgt = eos_tokens_tgt.data();
+    llama_token_eos(model_tgt, eos_ptr_tgt);
+
+    const int n_eos_dft = llama_n_eos(model_dft);
+    std::vector<int32_t> eos_tokens_dft(n_eos_dft, 0);
+    int32_t * eos_ptr_dft = eos_tokens_dft.data();
+    llama_token_eos(model_dft, eos_ptr_dft);
 
     if (
         llama_add_bos_token(model_tgt) != llama_add_bos_token(model_dft) ||
         llama_add_eos_token(model_tgt) != llama_add_eos_token(model_dft) ||
         llama_token_bos(model_tgt) != llama_token_bos(model_dft) ||
-        llama_token_eos(model_tgt) != llama_token_eos(model_dft)
+        eos_tokens_tgt != eos_tokens_dft
     ) {
        fprintf(stderr, "%s: error: draft model special tokens must match target model to use speculation\n", __func__);
        return 1;
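Note that operator!= on std::vector compares element-wise in order, so two models declaring the same EOS ids in a different order would still fail this check. If order-insensitive matching were wanted, a sketch (an assumption about intent, not what this commit does):

    #include <algorithm>
    #include <vector>

    // Order-insensitive comparison of two EOS id lists; arguments are taken
    // by value so the caller's vectors stay unsorted.
    static bool same_eos_set(std::vector<int32_t> a, std::vector<int32_t> b) {
        std::sort(a.begin(), a.end());
        std::sort(b.begin(), b.end());
        return a == b;
    }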

gguf-py/gguf/constants.py (+4 -2)

@@ -88,7 +88,7 @@ class Tokenizer:
     SCORES      = "tokenizer.ggml.scores"
     MERGES      = "tokenizer.ggml.merges"
     BOS_ID      = "tokenizer.ggml.bos_token_id"
-    EOS_ID      = "tokenizer.ggml.eos_token_id"
+    EOS_ID      = "tokenizer.ggml.eos_token_id"  # prefer EOS_ID_LIST
     UNK_ID      = "tokenizer.ggml.unknown_token_id"
     SEP_ID      = "tokenizer.ggml.seperator_token_id"
     PAD_ID      = "tokenizer.ggml.padding_token_id"
@@ -107,6 +107,8 @@ class Tokenizer:
     SUFFIX_ID   = "tokenizer.ggml.suffix_token_id"
     MIDDLE_ID   = "tokenizer.ggml.middle_token_id"
     EOT_ID      = "tokenizer.ggml.eot_token_id"
+    EOS_ID_LIST = "tokenizer.ggml.eos_token_id_list"
+
 
 
 #
@@ -1091,7 +1093,7 @@ def get_type(val: Any) -> GGUFValueType:
 KEY_TOKENIZER_SCORES      = Keys.Tokenizer.SCORES
 KEY_TOKENIZER_MERGES      = Keys.Tokenizer.MERGES
 KEY_TOKENIZER_BOS_ID      = Keys.Tokenizer.BOS_ID
-KEY_TOKENIZER_EOS_ID      = Keys.Tokenizer.EOS_ID
+KEY_TOKENIZER_EOS_ID_LIST = Keys.Tokenizer.EOS_ID_LIST
 KEY_TOKENIZER_UNK_ID      = Keys.Tokenizer.UNK_ID
 KEY_TOKENIZER_SEP_ID      = Keys.Tokenizer.SEP_ID
 KEY_TOKENIZER_PAD_ID      = Keys.Tokenizer.PAD_ID

gguf-py/gguf/gguf_writer.py (+3 -3)

@@ -510,9 +510,9 @@ def add_token_scores(self, scores: Sequence[float]) -> None:
 
     def add_bos_token_id(self, id: int) -> None:
         self.add_uint32(Keys.Tokenizer.BOS_ID, id)
-
-    def add_eos_token_id(self, id: int) -> None:
-        self.add_uint32(Keys.Tokenizer.EOS_ID, id)
+
+    def add_eos_token_id_list(self, id: Sequence[int]) -> None:
+        self.add_array(Keys.Tokenizer.EOS_ID_LIST, id)
 
     def add_unk_token_id(self, id: int) -> None:
         self.add_uint32(Keys.Tokenizer.UNK_ID, id)
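The writer stores the list as a GGUF array under the new tokenizer.ggml.eos_token_id_list key. For the read side, a hedged sketch against the public gguf C API (assumes the ids were written as an int32 array; error handling elided):

    #include <cstdint>
    #include <vector>
    #include "ggml.h"  // gguf_* declarations lived in ggml.h at the time of this commit

    // Sketch: read tokenizer.ggml.eos_token_id_list straight out of a GGUF file.
    static std::vector<int32_t> read_eos_list(const char * fname) {
        std::vector<int32_t> eos;
        struct gguf_init_params params = { /*no_alloc =*/ true, /*ctx =*/ NULL };
        struct gguf_context * gctx = gguf_init_from_file(fname, params);
        if (gctx == NULL) {
            return eos;
        }
        const int key = gguf_find_key(gctx, "tokenizer.ggml.eos_token_id_list");
        if (key >= 0) {
            const int32_t * data = (const int32_t *) gguf_get_arr_data(gctx, key);
            eos.assign(data, data + gguf_get_arr_n(gctx, key));
        }
        gguf_free(gctx);
        return eos;
    }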
