Skip to content
26 changes: 26 additions & 0 deletions cpp/conv_templates.cc
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,31 @@ Conversation RWKV() {
return conv;
}

/*!
 * \brief Conversation template for the RWKV World series of models.
 *
 * RWKV World models use a different prompt format from other RWKV variants
 * (e.g. RWKV Raven): the system prompt is a one-shot "(User: ... Assistant: ...)"
 * exchange, turns are separated by a blank line, and generation stops on a
 * blank line or on token id 0.
 *
 * \return A Conversation configured for RWKV World models.
 */
Conversation RWKVWorld() {
  const std::string kUserPrefix = "User: ";
  const std::string kAssistantPrefix =
      "Assistant: Hi. I am your assistant and I will provide expert "
      "full response in full details. Please feel free to ask any question and I will always answer it.";
  const std::string kDoubleNewLine = "\n\n";
  const std::string prompt =
      "(" + kUserPrefix + "hi" + kDoubleNewLine + kAssistantPrefix + kDoubleNewLine + ")";
  Conversation conv;
  // NOTE: must match the key registered in Conversation::FromTemplate
  // ("rwkv_world"); a mismatch makes the template unresolvable by name.
  conv.name = "rwkv_world";
  conv.system = prompt;
  conv.roles = {"User", "Assistant"};
  conv.messages = {};
  conv.separator_style = SeparatorStyle::kSepRoleMsg;
  conv.offset = 0;
  conv.seps = {"\n\n"};
  conv.role_msg_sep = ": ";
  conv.role_empty_sep = ":";
  conv.stop_str = "\n\n";
  // TODO(mlc-team): add eos to mlc-chat-config
  // and remove eos from stop token setting.
  conv.stop_tokens = {0};
  conv.add_bos = false;
  return conv;
}

Conversation Gorilla() {
Conversation conv;
conv.name = "gorilla_v0";
Expand Down Expand Up @@ -532,6 +557,7 @@ Conversation Conversation::FromTemplate(const std::string& name) {
{"vicuna_v1.1", VicunaV11},
{"conv_one_shot", ConvOneShot},
{"redpajama_chat", RedPajamaChat},
{"rwkv_world", RWKVWorld},
{"rwkv", RWKV},
{"gorilla", Gorilla},
{"guanaco", Guanaco},
Expand Down
35 changes: 33 additions & 2 deletions cpp/llm_chat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,12 @@ std::unique_ptr<Tokenizer> TokenizerFromPath(const std::string& _path) {
std::filesystem::path path(_path);
std::filesystem::path sentencepiece;
std::filesystem::path huggingface;
std::filesystem::path rwkvworld;
CHECK(std::filesystem::exists(path)) << "Cannot find tokenizer via path: " << _path;
if (std::filesystem::is_directory(path)) {
sentencepiece = path / "tokenizer.model";
huggingface = path / "tokenizer.json";
rwkvworld = path / "tokenizer_model";
// Check ByteLevelBPE
{
std::filesystem::path merges_path = path / "merges.txt";
Expand All @@ -76,13 +78,17 @@ std::unique_ptr<Tokenizer> TokenizerFromPath(const std::string& _path) {
} else {
sentencepiece = path.parent_path() / "tokenizer.model";
huggingface = path.parent_path() / "tokenizer.json";
rwkvworld = path.parent_path() / "tokenizer_model";
}
if (std::filesystem::exists(sentencepiece)) {
return Tokenizer::FromBlobSentencePiece(LoadBytesFromFile(sentencepiece.string()));
}
if (std::filesystem::exists(huggingface)) {
return Tokenizer::FromBlobJSON(LoadBytesFromFile(huggingface.string()));
}
if (std::filesystem::exists(rwkvworld)) {
return Tokenizer::FromBlobRWKVWorld(rwkvworld.string());
}
LOG(FATAL) << "Cannot find any tokenizer under: " << _path;
}

Expand Down Expand Up @@ -115,6 +121,14 @@ inline std::string Concat(const std::vector<std::string>& inputs) {
return os.str();
}

/*!
 * \brief Check whether a string contains both "rwkv" and "world",
 *        case-insensitively.
 *
 * Used to detect RWKV World models by their model name, since they require a
 * dedicated conversation template.
 *
 * \param str The string (typically a model name) to inspect.
 * \return true if both substrings are present, false otherwise.
 */
bool containsBothWords(const std::string& str) {
  std::string lowerStr(str.size(), '\0');
  // Cast through unsigned char: calling std::tolower with a negative char
  // value (possible for bytes >= 0x80 when char is signed) is undefined
  // behavior.
  std::transform(str.begin(), str.end(), lowerStr.begin(),
                 [](unsigned char c) { return static_cast<char>(std::tolower(c)); });
  return lowerStr.find("rwkv") != std::string::npos &&
         lowerStr.find("world") != std::string::npos;
}

//------------------------------
// Chat module
//------------------------------
Expand Down Expand Up @@ -153,6 +167,11 @@ class LLMChat {
*/
void LoadJSONOverride(const picojson::value& config_json, bool partial_update = false) {
picojson::object config = config_json.get<picojson::object>();
std::string model_name;
if (config.count("model_name")) {
CHECK(config["model_name"].is<std::string>());
model_name = config["model_name"].get<std::string>();
}
if (config.count("temperature")) {
CHECK(config["temperature"].is<double>());
this->temperature_ = config["temperature"].get<double>();
Expand Down Expand Up @@ -193,7 +212,13 @@ class LLMChat {
if (config.count("conv_template")) {
ICHECK(config["conv_template"].is<std::string>());
std::string conv_template = config["conv_template"].get<std::string>();
this->conversation_ = Conversation::FromTemplate(conv_template);
// The RWKV World series of models have different prompts compared to other models like RWKV Raven.
if (containsBothWords(model_name)){
this->conversation_ = Conversation::FromTemplate("rwkv_world");
}
else{
this->conversation_ = Conversation::FromTemplate(conv_template);
}
if (config.count("conv_config")) {
// conv_config can override conv_template
this->conversation_.LoadJSONOverride(config["conv_config"], true);
Expand Down Expand Up @@ -354,7 +379,13 @@ class LLMChat {
if (this->max_window_size_ == -1) {
this->max_window_size_ = std::numeric_limits<int64_t>::max();
}
this->conversation_ = Conversation::FromTemplate(conv_template);
// The RWKV World series of models have different prompts compared to other models like RWKV Raven.
if(containsBothWords(this->model_name_)){
this->conversation_ = Conversation::FromTemplate("rwkv-world");
}
else{
this->conversation_ = Conversation::FromTemplate(conv_template);
}
this->temperature_ = temperature;
this->top_p_ = top_p;
this->mean_gen_len_ = mean_gen_len;
Expand Down