Commit 50293bc

ngxson authored and Nexesenex committed
Add MiniCPM, Deepseek V2 chat template + clean up llama_chat_apply_template_internal (ggml-org#8172)
* tmp_contains
* minicpm chat template
* add DeepSeek Lite template
* change deepseek-lite to deepseek2
* correct code comment
* correct code from master branch
1 parent 1b1ed62 commit 50293bc
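
The new template names plug into the existing dispatcher, so they can be selected explicitly through the public llama_chat_apply_template API, or picked up automatically when the detection substrings appear in a model's embedded Jinja template. Below is a minimal call-site sketch, assuming the llama.h declarations current at this commit (llama_chat_message, llama_chat_apply_template); the buffer sizing and grow-and-retry step are illustrative, not part of this change.

// Sketch: format a short conversation with the newly added "deepseek2" template.
// A nullptr model forces the explicit template name instead of the template stored
// in the model's GGUF metadata; swap in "minicpm" to exercise the other new branch.
#include "llama.h"
#include <cstdio>
#include <vector>

int main() {
    std::vector<llama_chat_message> chat = {
        {"system",    "You are a helpful assistant."},
        {"user",      "Hello"},
        {"assistant", "Hi there"},
        {"user",      "Who are you?"},
    };
    std::vector<char> buf(1024);
    int32_t res = llama_chat_apply_template(nullptr, "deepseek2",
                                            chat.data(), chat.size(),
                                            /*add_ass=*/true,
                                            buf.data(), buf.size());
    if (res < 0) {
        fprintf(stderr, "template not supported\n");
        return 1;
    }
    if ((size_t) res > buf.size()) {
        // the return value is the full formatted length, so grow and retry on truncation
        buf.resize(res);
        res = llama_chat_apply_template(nullptr, "deepseek2",
                                        chat.data(), chat.size(), true,
                                        buf.data(), buf.size());
    }
    printf("%.*s\n", res, buf.data());
    return 0;
}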

File tree

2 files changed: +198 / -17 lines changed


llama.cpp

Lines changed: 47 additions & 17 deletions
@@ -22271,24 +22271,27 @@ static int32_t llama_chat_apply_template_internal(
     std::string & dest, bool add_ass) {
     // Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527
     std::stringstream ss;
-    if (tmpl == "chatml" || tmpl.find("<|im_start|>") != std::string::npos) {
+    auto tmpl_contains = [&tmpl](std::string haystack) -> bool {
+        return tmpl.find(haystack) != std::string::npos;
+    };
+    if (tmpl == "chatml" || tmpl_contains("<|im_start|>")) {
         // chatml template
         for (auto message : chat) {
             ss << "<|im_start|>" << message->role << "\n" << message->content << "<|im_end|>\n";
         }
         if (add_ass) {
             ss << "<|im_start|>assistant\n";
         }
-    } else if (tmpl == "llama2" || tmpl == "mistral" || tmpl.find("[INST]") != std::string::npos) {
+    } else if (tmpl == "llama2" || tmpl == "mistral" || tmpl_contains("[INST]")) {
         // llama2 template and its variants
         // [variant] support system message
-        bool support_system_message = tmpl.find("<<SYS>>") != std::string::npos || tmpl == "mistral";
+        bool support_system_message = tmpl_contains("<<SYS>>") || tmpl == "mistral";
         // [variant] space before + after response
-        bool space_around_response = tmpl.find("' ' + eos_token") != std::string::npos;
+        bool space_around_response = tmpl_contains("' ' + eos_token");
         // [variant] add BOS inside history
-        bool add_bos_inside_history = tmpl.find("bos_token + '[INST]") != std::string::npos;
+        bool add_bos_inside_history = tmpl_contains("bos_token + '[INST]");
         // [variant] trim spaces from the input message
-        bool strip_message = tmpl.find("content.strip()") != std::string::npos;
+        bool strip_message = tmpl_contains("content.strip()");
         // construct the prompt
         bool is_inside_turn = true; // skip BOS at the beginning
         ss << "[INST] ";
@@ -22314,7 +22317,7 @@ static int32_t llama_chat_apply_template_internal(
             }
         }
         // llama2 templates seem to not care about "add_generation_prompt"
-    } else if (tmpl == "phi3" || (tmpl.find("<|assistant|>") != std::string::npos && tmpl.find("<|end|>") != std::string::npos)) {
+    } else if (tmpl == "phi3" || (tmpl_contains("<|assistant|>") && tmpl_contains("<|end|>"))) {
         // Phi 3
         for (auto message : chat) {
             std::string role(message->role);
@@ -22323,15 +22326,15 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "<|assistant|>\n";
         }
-    } else if (tmpl == "zephyr" || tmpl.find("<|user|>") != std::string::npos) {
+    } else if (tmpl == "zephyr" || tmpl_contains("<|user|>")) {
         // zephyr template
         for (auto message : chat) {
             ss << "<|" << message->role << "|>" << "\n" << message->content << "<|endoftext|>\n";
         }
         if (add_ass) {
             ss << "<|assistant|>\n";
         }
-    } else if (tmpl == "monarch" || tmpl.find("bos_token + message['role']") != std::string::npos) {
+    } else if (tmpl == "monarch" || tmpl_contains("bos_token + message['role']")) {
         // mlabonne/AlphaMonarch-7B template (the <s> is included inside history)
         for (auto message : chat) {
             std::string bos = (message == chat.front()) ? "" : "<s>"; // skip BOS for first message
@@ -22340,7 +22343,7 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "<s>assistant\n";
         }
-    } else if (tmpl == "gemma" || tmpl == "gemma2" || tmpl.find("<start_of_turn>") != std::string::npos) {
+    } else if (tmpl == "gemma" || tmpl == "gemma2" || tmpl_contains("<start_of_turn>")) {
         // google/gemma-7b-it
         std::string system_prompt = "";
         for (auto message : chat) {
@@ -22362,7 +22365,7 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "<start_of_turn>model\n";
         }
-    } else if (tmpl == "orion" || tmpl.find("'\\n\\nAssistant: ' + eos_token") != std::string::npos) {
+    } else if (tmpl == "orion" || tmpl_contains("'\\n\\nAssistant: ' + eos_token")) {
         // OrionStarAI/Orion-14B-Chat
         std::string system_prompt = "";
         for (auto message : chat) {
@@ -22382,7 +22385,7 @@ static int32_t llama_chat_apply_template_internal(
                 ss << message->content << "</s>";
             }
         }
-    } else if (tmpl == "openchat" || tmpl.find("GPT4 Correct ") != std::string::npos) {
+    } else if (tmpl == "openchat" || tmpl_contains("GPT4 Correct ")) {
         // openchat/openchat-3.5-0106,
         for (auto message : chat) {
             std::string role(message->role);
@@ -22396,13 +22399,13 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "GPT4 Correct Assistant:";
         }
-    } else if (tmpl == "vicuna" || tmpl == "vicuna-orca" || (tmpl.find("USER: ") != std::string::npos && tmpl.find("ASSISTANT: ") != std::string::npos)) {
+    } else if (tmpl == "vicuna" || tmpl == "vicuna-orca" || (tmpl_contains("USER: ") && tmpl_contains("ASSISTANT: "))) {
         // eachadea/vicuna-13b-1.1 (and Orca variant)
         for (auto message : chat) {
             std::string role(message->role);
             if (role == "system") {
                 // Orca-Vicuna variant uses a system prefix
-                if (tmpl == "vicuna-orca" || tmpl.find("SYSTEM: ") != std::string::npos) {
+                if (tmpl == "vicuna-orca" || tmpl_contains("SYSTEM: ")) {
                     ss << "SYSTEM: " << message->content << "\n";
                 } else {
                     ss << message->content << "\n\n";
@@ -22416,7 +22419,7 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "ASSISTANT:";
         }
-    } else if (tmpl == "deepseek" || (tmpl.find("### Instruction:") != std::string::npos && tmpl.find("<|EOT|>") != std::string::npos)) {
+    } else if (tmpl == "deepseek" || (tmpl_contains("### Instruction:") && tmpl_contains("<|EOT|>"))) {
         // deepseek-ai/deepseek-coder-33b-instruct
         for (auto message : chat) {
             std::string role(message->role);
@@ -22431,7 +22434,7 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "### Response:\n";
         }
-    } else if (tmpl == "command-r" || (tmpl.find("<|START_OF_TURN_TOKEN|>") != std::string::npos && tmpl.find("<|USER_TOKEN|>") != std::string::npos)) {
+    } else if (tmpl == "command-r" || (tmpl_contains("<|START_OF_TURN_TOKEN|>") && tmpl_contains("<|USER_TOKEN|>"))) {
         // CohereForAI/c4ai-command-r-plus
         for (auto message : chat) {
             std::string role(message->role);
@@ -22446,7 +22449,7 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>";
         }
-    } else if (tmpl == "llama3" || (tmpl.find("<|start_header_id|>") != std::string::npos && tmpl.find("<|end_header_id|>") != std::string::npos)) {
+    } else if (tmpl == "llama3" || (tmpl_contains("<|start_header_id|>") && tmpl_contains("<|end_header_id|>"))) {
         // Llama 3
         for (auto message : chat) {
             std::string role(message->role);
@@ -22455,6 +22458,33 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "<|start_header_id|>assistant<|end_header_id|>\n\n";
         }
+    } else if (tmpl == "minicpm" || tmpl_contains(u8"<用户>")) {
+        // MiniCPM-3B-OpenHermes-2.5-v2-GGUF
+        for (auto message : chat) {
+            std::string role(message->role);
+            if (role == "user") {
+                ss << u8"<用户>";
+                ss << trim(message->content);
+                ss << "<AI>";
+            } else {
+                ss << trim(message->content);
+            }
+        }
+    } else if (tmpl == "deepseek2" || tmpl_contains("'Assistant: ' + message['content'] + eos_token")) {
+        // DeepSeek-V2
+        for (auto message : chat) {
+            std::string role(message->role);
+            if (role == "system") {
+                ss << message->content << "\n\n";
+            } else if (role == "user") {
+                ss << "User: " << message->content << "\n\n";
+            } else if (role == "assistant") {
+                ss << "Assistant: " << message->content << u8"<|end▁of▁sentence|>";
+            }
+        }
+        if (add_ass) {
+            ss << "Assistant:";
+        }
     } else {
         // template not supported
         return -1;
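
For reference, the two new branches render a conversation roughly as follows. This is a standalone illustration that mirrors the added code above; the local trim() helper stands in for llama.cpp's internal one, and the source file is assumed to be UTF-8 so the literal marker tokens survive as-is.

// Standalone sketch of the prompt shapes produced by the new minicpm and deepseek2 branches.
#include <iostream>
#include <string>
#include <vector>

struct msg { std::string role, content; };

// stand-in for llama.cpp's internal trim()
static std::string trim(const std::string & s) {
    size_t b = s.find_first_not_of(" \t\n");
    size_t e = s.find_last_not_of(" \t\n");
    return b == std::string::npos ? "" : s.substr(b, e - b + 1);
}

int main() {
    std::vector<msg> chat = {
        {"system", "You are a helpful assistant."},
        {"user", "Hello"},
        {"assistant", "Hi there"},
    };

    // minicpm: user turns are wrapped as <用户>...<AI>, every other turn is emitted verbatim (trimmed)
    std::string minicpm;
    for (const auto & m : chat) {
        if (m.role == "user") {
            minicpm += "<用户>" + trim(m.content) + "<AI>";
        } else {
            minicpm += trim(m.content);
        }
    }

    // deepseek2: "User: "/"Assistant: " prefixes, assistant turns end with <|end▁of▁sentence|>,
    // and add_ass appends a trailing "Assistant:" to cue the model's reply
    std::string deepseek2;
    for (const auto & m : chat) {
        if (m.role == "system") {
            deepseek2 += m.content + "\n\n";
        } else if (m.role == "user") {
            deepseek2 += "User: " + m.content + "\n\n";
        } else if (m.role == "assistant") {
            deepseek2 += "Assistant: " + m.content + "<|end▁of▁sentence|>";
        }
    }
    deepseek2 += "Assistant:"; // add_ass == true

    std::cout << minicpm << "\n---\n" << deepseek2 << "\n";
    return 0;
}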
