Skip to content
46 changes: 43 additions & 3 deletions cpp/conv_templates.cc
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,31 @@ Conversation RWKV() {
return conv;
}

Conversation RWKVWorld() {
  // The RWKV World models are primed with a one-shot exchange, wrapped in
  // parentheses, that establishes the assistant persona before the real chat.
  const std::string user_prefix = "User: ";
  const std::string assistant_prefix =
      "Assistant: Hi. I am your assistant and I will provide expert "
      "full response in full details. Please feel free to ask any question and I will always answer it.";
  const std::string double_newline = "\n\n";

  // Assemble "(User: hi\n\nAssistant: ...\n\n)" piecewise to avoid temporaries.
  std::string prompt;
  prompt += "(";
  prompt += user_prefix;
  prompt += "hi";
  prompt += double_newline;
  prompt += assistant_prefix;
  prompt += double_newline;
  prompt += ")";

  Conversation conv;
  conv.name = "rwkv-world";
  conv.system = prompt;
  conv.roles = {"User", "Assistant"};
  conv.messages = {};
  conv.separator_style = SeparatorStyle::kSepRoleMsg;
  conv.offset = 0;
  conv.seps = {"\n\n"};
  conv.role_msg_sep = ": ";
  conv.role_empty_sep = ":";
  conv.stop_str = "\n\n";
  // TODO(mlc-team): add eos to mlc-chat-config
  // and remove eos from stop token setting.
  conv.stop_tokens = {0};
  conv.add_bos = false;
  return conv;
}

Conversation Gorilla() {
Conversation conv;
conv.name = "gorilla_v0";
Expand Down Expand Up @@ -521,9 +546,23 @@ Conversation GLM() {

} // namespace

/*!
 * \brief Check whether a model name refers to an RWKV "World" model.
 * \param str The model name (case-insensitive match).
 * \return true if the lowered name contains both "rwkv" and "world".
 */
bool containsBothWordsInRWKVWorld(const std::string& str) {
  std::string lowerStr(str.size(), '\0');
  // NOTE: calling ::tolower on a plain (possibly signed) char is undefined
  // behavior for values outside [0, 127]; cast through unsigned char first so
  // non-ASCII bytes in model names are handled safely.
  std::transform(str.begin(), str.end(), lowerStr.begin(),
                 [](unsigned char c) { return static_cast<char>(::tolower(c)); });

  return (lowerStr.find("rwkv") != std::string::npos) &&
         (lowerStr.find("world") != std::string::npos);
}

using ConvFactory = Conversation (*)();

Conversation Conversation::FromTemplate(const std::string& name) {
Conversation Conversation::FromTemplate(const std::string& name, const std::string& model_name) {
std::string new_name = name;
// The RWKV World series of models have different prompts compared to other models like RWKV Raven.
if (containsBothWordsInRWKVWorld(model_name)){
new_name = "rwkv_world";
}

static std::unordered_map<std::string, ConvFactory> factory = {
{"llama_default", LlamaDefault},
{"llama-2", Llama2},
Expand All @@ -532,6 +571,7 @@ Conversation Conversation::FromTemplate(const std::string& name) {
{"vicuna_v1.1", VicunaV11},
{"conv_one_shot", ConvOneShot},
{"redpajama_chat", RedPajamaChat},
{"rwkv_world", RWKVWorld},
{"rwkv", RWKV},
{"gorilla", Gorilla},
{"guanaco", Guanaco},
Expand All @@ -548,9 +588,9 @@ Conversation Conversation::FromTemplate(const std::string& name) {
{"wizard_coder_or_math", WizardCoderOrMATH},
{"glm", GLM},
};
auto it = factory.find(name);
auto it = factory.find(new_name);
if (it == factory.end()) {
LOG(FATAL) << "Unknown conversation template: " << name;
LOG(FATAL) << "Unknown conversation template: " << new_name;
}
return it->second();
}
Expand Down
3 changes: 2 additions & 1 deletion cpp/conversation.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,9 @@ class Conversation {
/**
* \brief Create conversation from existing registered template.
* \param name The template name.
* \param model_name The model name.
*/
static Conversation FromTemplate(const std::string& name);
static Conversation FromTemplate(const std::string& name, const std::string& model_name);

/*!
* \brief Load JSON config in raw string and overrides options.
Expand Down
15 changes: 13 additions & 2 deletions cpp/llm_chat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,12 @@ std::unique_ptr<Tokenizer> TokenizerFromPath(const std::string& _path) {
std::filesystem::path path(_path);
std::filesystem::path sentencepiece;
std::filesystem::path huggingface;
std::filesystem::path rwkvworld;
CHECK(std::filesystem::exists(path)) << "Cannot find tokenizer via path: " << _path;
if (std::filesystem::is_directory(path)) {
sentencepiece = path / "tokenizer.model";
huggingface = path / "tokenizer.json";
rwkvworld = path / "tokenizer_model";
// Check ByteLevelBPE
{
std::filesystem::path merges_path = path / "merges.txt";
Expand All @@ -76,13 +78,17 @@ std::unique_ptr<Tokenizer> TokenizerFromPath(const std::string& _path) {
} else {
sentencepiece = path.parent_path() / "tokenizer.model";
huggingface = path.parent_path() / "tokenizer.json";
rwkvworld = path.parent_path() / "tokenizer_model";
}
if (std::filesystem::exists(sentencepiece)) {
return Tokenizer::FromBlobSentencePiece(LoadBytesFromFile(sentencepiece.string()));
}
if (std::filesystem::exists(huggingface)) {
return Tokenizer::FromBlobJSON(LoadBytesFromFile(huggingface.string()));
}
if (std::filesystem::exists(rwkvworld)) {
return Tokenizer::FromBlobRWKVWorld(rwkvworld.string());
}
LOG(FATAL) << "Cannot find any tokenizer under: " << _path;
}

Expand Down Expand Up @@ -153,6 +159,11 @@ class LLMChat {
*/
void LoadJSONOverride(const picojson::value& config_json, bool partial_update = false) {
picojson::object config = config_json.get<picojson::object>();
std::string model_name;
if (config.count("model_name")) {
CHECK(config["model_name"].is<std::string>());
model_name = config["model_name"].get<std::string>();
}
if (config.count("temperature")) {
CHECK(config["temperature"].is<double>());
this->temperature_ = config["temperature"].get<double>();
Expand Down Expand Up @@ -193,7 +204,7 @@ class LLMChat {
if (config.count("conv_template")) {
ICHECK(config["conv_template"].is<std::string>());
std::string conv_template = config["conv_template"].get<std::string>();
this->conversation_ = Conversation::FromTemplate(conv_template);
this->conversation_ = Conversation::FromTemplate(conv_template, model_name);
if (config.count("conv_config")) {
// conv_config can override conv_template
this->conversation_.LoadJSONOverride(config["conv_config"], true);
Expand Down Expand Up @@ -354,7 +365,7 @@ class LLMChat {
if (this->max_window_size_ == -1) {
this->max_window_size_ = std::numeric_limits<int64_t>::max();
}
this->conversation_ = Conversation::FromTemplate(conv_template);
this->conversation_ = Conversation::FromTemplate(conv_template, this->model_name_);
this->temperature_ = temperature;
this->top_p_ = top_p;
this->mean_gen_len_ = mean_gen_len;
Expand Down