@@ -260,23 +260,44 @@ static size_t validate_utf8(const std::string& text) {
260260// template utils
261261//
262262
263- // format rerank task: [BOS]query[EOS][SEP]doc[EOS]
264- static llama_tokens format_rerank (const struct llama_vocab * vocab, const llama_tokens & query, const llama_tokens & doc) {
263+ // format rerank task:
264+ // - using SEP token: [BOS]query[EOS][SEP]doc[EOS]
265+ // - using prompt: <rerank_prefix>query<rerank_suffix>doc
266+ static llama_tokens format_rerank (const struct llama_model * model, const llama_tokens & query, const llama_tokens & doc) {
267+ const llama_vocab * vocab = llama_model_get_vocab (model);
265268 llama_tokens result;
266269
267- // Get EOS token - use SEP token as fallback if EOS is not available
268- llama_token eos_token = llama_vocab_eos (vocab);
269- if (eos_token == LLAMA_TOKEN_NULL) {
270- eos_token = llama_vocab_sep (vocab);
271- }
270+ if (llama_vocab_sep (vocab) != LLAMA_TOKEN_NULL) {
271+ // Get EOS token - use SEP token as fallback if EOS is not available
272+ llama_token eos_token = llama_vocab_eos (vocab);
273+ if (eos_token == LLAMA_TOKEN_NULL) {
274+ eos_token = llama_vocab_sep (vocab);
275+ }
276+
277+ result.reserve (doc.size () + query.size () + 4 );
278+ result.push_back (llama_vocab_bos (vocab));
279+ result.insert (result.end (), query.begin (), query.end ());
280+ result.push_back (eos_token);
281+ result.push_back (llama_vocab_sep (vocab));
282+ result.insert (result.end (), doc.begin (), doc.end ());
283+ result.push_back (eos_token);
284+ } else {
285+ // using prompt template
286+ const char * prefix = llama_model_chat_template (model, " rerank_prefix" );
287+ const char * suffix = llama_model_chat_template (model, " rerank_suffix" );
288+
289+ if (prefix == NULL && suffix == NULL ) {
290+ throw std::runtime_error (" Rerank prompt template not found in the model\n " );
291+ }
272292
273- result.reserve (doc.size () + query.size () + 4 );
274- result.push_back (llama_vocab_bos (vocab));
275- result.insert (result.end (), query.begin (), query.end ());
276- result.push_back (eos_token);
277- result.push_back (llama_vocab_sep (vocab));
278- result.insert (result.end (), doc.begin (), doc.end ());
279- result.push_back (eos_token);
293+ const llama_tokens prefix_tokens = prefix ? common_tokenize (vocab, prefix, true , false ) : llama_tokens ();
294+ const llama_tokens suffix_tokens = suffix ? common_tokenize (vocab, suffix, false , false ) : llama_tokens ();
295+ result.reserve (prefix_tokens.size () + query.size () + suffix_tokens.size () + doc.size ());
296+ result.insert (result.end (), prefix_tokens.begin (), prefix_tokens.end ());
297+ result.insert (result.end (), query.begin (), query.end ());
298+ result.insert (result.end (), suffix_tokens.begin (), suffix_tokens.end ());
299+ result.insert (result.end (), doc.begin (), doc.end ());
300+ }
280301
281302 return result;
282303}
0 commit comments