@@ -2417,14 +2417,21 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(
         }
     }
 
+    const int n_eos = llama_n_eos(llama_get_model(lctx));
+    std::vector<int32_t> eos_tokens(n_eos, 0);
+    int32_t * eos_ptr = eos_tokens.data();
+    llama_token_eos(llama_get_model(lctx), eos_ptr);
     if (params.ignore_eos) {
-        params.sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
+        for (int32_t i = 0; i < n_eos; ++i) {
+            params.sparams.logit_bias[eos_ptr[i]] = -INFINITY;
+        }
     }
 
     if (params.warmup) {
         LOG("warming up the model with an empty run\n");
 
-        std::vector<llama_token> tmp = { llama_token_bos(model), llama_token_eos(model), };
+        std::vector<llama_token> tmp = { llama_token_bos(model) };
+        tmp.insert(tmp.end(), eos_tokens.begin(), eos_tokens.end());
         llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
         llama_kv_cache_clear(lctx);
         llama_synchronize(lctx);
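The hunk above replaces every use of the old single-token `llama_token_eos(model)` with a query-then-fill pattern: `llama_n_eos()` reports how many EOS tokens the model defines, and `llama_token_eos()` now writes them into a caller-supplied buffer. Below is a minimal sketch of how a caller might wrap that pattern into one helper; the two declarations are assumptions read off this diff, not a confirmed `llama.h` API.

```cpp
#include <cstdint>
#include <vector>

// Assumed declarations, inferred from this diff -- not a confirmed llama.h API.
struct llama_model;
extern "C" int  llama_n_eos    (const struct llama_model * model);
extern "C" void llama_token_eos(const struct llama_model * model, int32_t * buf);

// Hypothetical helper: collect all EOS tokens for a model into one vector.
static std::vector<int32_t> get_eos_tokens(const struct llama_model * model) {
    std::vector<int32_t> eos_tokens(llama_n_eos(model), 0);
    llama_token_eos(model, eos_tokens.data()); // fills the caller's buffer
    return eos_tokens;
}
```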
@@ -3357,8 +3364,17 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "hellaswag:                %s # default: false\n", params.hellaswag ? "true" : "false");
     fprintf(stream, "hellaswag_tasks:          %zu # default: 400\n", params.hellaswag_tasks);
 
-    const auto logit_bias_eos = sparams.logit_bias.find(llama_token_eos(llama_get_model(lctx)));
-    const bool ignore_eos = logit_bias_eos != sparams.logit_bias.end() && logit_bias_eos->second == -INFINITY;
+    const int n_eos = llama_n_eos(llama_get_model(lctx));
+    std::vector<int32_t> eos_tokens(n_eos, 0);
+    int32_t * eos_ptr = eos_tokens.data();
+    llama_token_eos(llama_get_model(lctx), eos_ptr);
+    bool ignore_eos = false;
+    for (auto eos : eos_tokens) {
+        const auto logit_bias_eos = sparams.logit_bias.find(eos);
+        if (logit_bias_eos != sparams.logit_bias.end() && logit_bias_eos->second == -INFINITY) {
+            ignore_eos = true;
+        }
+    }
     fprintf(stream, "ignore_eos:               %s # default: false\n", ignore_eos ? "true" : "false");
 
     yaml_dump_string_multiline(stream, "in_prefix", params.input_prefix.c_str());
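For the YAML dump, `ignore_eos` is now derived by scanning every EOS token for a `-INFINITY` bias instead of looking up the single EOS id. The same check as a standalone helper is sketched below, with an early return once one banned token is found (the loop in the diff keeps scanning but computes the same answer); the `std::unordered_map<int32_t, float>` container is an assumption about what `sparams.logit_bias` is.

```cpp
#include <cmath>          // INFINITY
#include <cstdint>
#include <unordered_map>
#include <vector>

// Returns true if any EOS token carries a -INFINITY logit bias, i.e. the
// user effectively asked to ignore EOS. Early-exits on the first hit.
static bool any_eos_banned(const std::unordered_map<int32_t, float> & logit_bias,
                           const std::vector<int32_t> & eos_tokens) {
    for (const int32_t eos : eos_tokens) {
        const auto it = logit_bias.find(eos);
        if (it != logit_bias.end() && it->second == -INFINITY) {
            return true;
        }
    }
    return false;
}
```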
@@ -3371,7 +3387,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
 
     fprintf(stream, "logit_bias:\n");
     for (std::pair<llama_token, float> lb : sparams.logit_bias) {
-        if (ignore_eos && lb.first == logit_bias_eos->first) {
+        if (ignore_eos && std::count(eos_tokens.begin(), eos_tokens.end(), lb.first)) {
             continue;
         }
         fprintf(stream, "  %d: %f", lb.first, lb.second);
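Note that `std::count` (which needs `<algorithm>` to be included) walks all of `eos_tokens` once per `logit_bias` entry. With a handful of EOS tokens that is harmless; if the list ever grew, a set built once before the loop would make the membership test average O(1). A hypothetical alternative, not part of the diff:

```cpp
#include <cstdint>
#include <unordered_set>
#include <vector>

// Hypothetical: build once before the dump loop, instead of calling
// std::count on the vector for every logit_bias entry.
static bool is_eos_token(const std::unordered_set<int32_t> & eos_set, int32_t token) {
    return eos_set.count(token) != 0;
}
// Usage sketch:
//   const std::unordered_set<int32_t> eos_set(eos_tokens.begin(), eos_tokens.end());
//   if (ignore_eos && is_eos_token(eos_set, lb.first)) { continue; }
```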