26 changes: 26 additions & 0 deletions cpp/conv_templates.cc
@@ -215,6 +215,31 @@ Conversation RWKV() {
   return conv;
 }

+Conversation RWKVWorld() {
+  const std::string kUserPrefix = "User: ";
+  const std::string kAssistantPrefix = "Assistant: Hi. I am your assistant and I will provide expert "
+      "full response in full details. Please feel free to ask any question and I will always answer it.";
+  const std::string kDoubleNewLine = "\n\n";
+  const std::string prompt =
+      "(" + kUserPrefix + "hi" + kDoubleNewLine + kAssistantPrefix + kDoubleNewLine + ")";
+  Conversation conv;
+  conv.name = "rwkv-world";
+  conv.system = prompt;
+  conv.roles = {"User", "Assistant"};
+  conv.messages = {};
+  conv.separator_style = SeparatorStyle::kSepRoleMsg;
+  conv.offset = 0;
+  conv.seps = {"\n\n"};
+  conv.role_msg_sep = ": ";
+  conv.role_empty_sep = ":";
+  conv.stop_str = "\n\n";
+  // TODO(mlc-team): add eos to mlc-chat-config
+  // and remove eos from stop token setting.
+  conv.stop_tokens = {0};
+  conv.add_bos = false;
+  return conv;
+}
+
 Conversation Gorilla() {
   Conversation conv;
   conv.name = "gorilla_v0";
@@ -532,6 +557,7 @@ Conversation Conversation::FromTemplate(const std::string& name) {
       {"vicuna_v1.1", VicunaV11},
       {"conv_one_shot", ConvOneShot},
       {"redpajama_chat", RedPajamaChat},
+      {"rwkv_world", RWKVWorld},
       {"rwkv", RWKV},
       {"gorilla", Gorilla},
       {"guanaco", Guanaco},
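To make the template above concrete, here is a minimal Python sketch of how these fields plausibly combine into a prompt under the kSepRoleMsg separator style. This is an illustration written for this review, not mlc-llm's actual C++ rendering code; the render() helper and its handling of empty messages are assumptions inferred from the field names.

# Illustrative sketch (not mlc-llm's real renderer): how conv.system, conv.seps,
# conv.role_msg_sep, and conv.role_empty_sep plausibly combine into a prompt.
SYSTEM = ("(User: hi\n\n"
          "Assistant: Hi. I am your assistant and I will provide expert full "
          "response in full details. Please feel free to ask any question and "
          "I will always answer it.\n\n)")

def render(messages, sep="\n\n", role_msg_sep=": ", role_empty_sep=":"):
    # Each turn renders as "<role>: <message>"; the final empty Assistant turn
    # ends with a bare ":" so the model completes it. Turns join on seps[0].
    parts = [SYSTEM]
    for role, msg in messages:
        parts.append(role + role_msg_sep + msg if msg else role + role_empty_sep)
    return sep.join(parts)

print(render([("User", "What is RWKV?"), ("Assistant", "")]))
# Decoding then stops at stop_str "\n\n" or at stop token id 0.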
6 changes: 6 additions & 0 deletions cpp/llm_chat.cc
@@ -56,10 +56,12 @@ std::unique_ptr<Tokenizer> TokenizerFromPath(const std::string& _path) {
   std::filesystem::path path(_path);
   std::filesystem::path sentencepiece;
   std::filesystem::path huggingface;
+  std::filesystem::path rwkvworld;
   CHECK(std::filesystem::exists(path)) << "Cannot find tokenizer via path: " << _path;
   if (std::filesystem::is_directory(path)) {
     sentencepiece = path / "tokenizer.model";
     huggingface = path / "tokenizer.json";
+    rwkvworld = path / "tokenizer_model";
     // Check ByteLevelBPE
     {
       std::filesystem::path merges_path = path / "merges.txt";
@@ -76,13 +78,17 @@ std::unique_ptr<Tokenizer> TokenizerFromPath(const std::string& _path) {
   } else {
     sentencepiece = path.parent_path() / "tokenizer.model";
     huggingface = path.parent_path() / "tokenizer.json";
+    rwkvworld = path.parent_path() / "tokenizer_model";
   }
   if (std::filesystem::exists(sentencepiece)) {
     return Tokenizer::FromBlobSentencePiece(LoadBytesFromFile(sentencepiece.string()));
   }
   if (std::filesystem::exists(huggingface)) {
     return Tokenizer::FromBlobJSON(LoadBytesFromFile(huggingface.string()));
   }
+  if (std::filesystem::exists(rwkvworld)) {
+    return Tokenizer::FromBlobRWKVWorld(rwkvworld.string());
+  }
   LOG(FATAL) << "Cannot find any tokenizer under: " << _path;
 }

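The lookup order added here is easiest to see side by side. The following Python sketch mirrors the C++ logic above; the ByteLevelBPE merges.txt probe is omitted for brevity, and detect_tokenizer is a name invented for this illustration.

from pathlib import Path

def detect_tokenizer(path: str) -> str:
    # Mirrors TokenizerFromPath: probe for tokenizer.model (SentencePiece),
    # then tokenizer.json (HuggingFace), then tokenizer_model (RWKV World).
    base = Path(path) if Path(path).is_dir() else Path(path).parent
    for filename, kind in [("tokenizer.model", "sentencepiece"),
                           ("tokenizer.json", "huggingface"),
                           ("tokenizer_model", "rwkv_world")]:
        if (base / filename).exists():
            return kind
    raise FileNotFoundError(f"Cannot find any tokenizer under: {path}")

Note one asymmetry in the C++ above: the RWKV World branch passes the file path to Tokenizer::FromBlobRWKVWorld, whereas the other two loaders receive the file contents via LoadBytesFromFile.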
4 changes: 2 additions & 2 deletions mlc_llm/core.py
@@ -556,7 +556,7 @@ def build_model_from_args(args: argparse.Namespace):
         mod, param_manager, params, model_config = minigpt.get_model(args)
     elif args.model_category == "gptj":
         mod, param_manager, params, model_config = gptj.get_model(args, config)
-    elif args.model_category == "rwkv":
+    elif args.model_category == "rwkv" or args.model_category == "rwkv_world":
         mod, param_manager, params, model_config = rwkv.get_model(args, config)
     elif args.model_category == "chatglm":
         mod, param_manager, params, model_config = chatglm.get_model(args, config)
@@ -572,7 +572,7 @@ def build_model_from_args(args: argparse.Namespace):
     utils.save_params(new_params, args.artifact_path)
     if args.model_category != "minigpt":
         utils.copy_tokenizer(args)
-    if args.model_category == "rwkv":
+    if args.model_category == "rwkv" or args.model_category == "rwkv_world":
         # TODO: refactor config into model definition
         dump_mlc_chat_config(args, top_p=0.6, temperature=1.2, repetition_penalty=0.996)
     else:
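Since rwkv_world reuses the rwkv model definition but gets its own chat-config defaults, the dumped config plausibly contains an excerpt like the sketch below. Only the three sampling values come from the dump_mlc_chat_config call above; the conv_template value and the exact field set are assumptions (the template name follows from the utils.py mapping further down).

# Hypothetical excerpt of the dumped mlc-chat-config.json for a World model;
# the real file contains more fields than shown here.
rwkv_world_chat_config = {
    "conv_template": "rwkv_world",   # assumed, via the utils.py mapping below
    "top_p": 0.6,                    # from dump_mlc_chat_config above
    "temperature": 1.2,
    "repetition_penalty": 0.996,
}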
3 changes: 3 additions & 0 deletions mlc_llm/dispatch/dispatch_tir_operator.py
@@ -21,6 +21,9 @@ def __init__(self, model: str):

         elif model == "rwkv":
             lookup = None
+
+        elif model == "rwkv_world":
+            lookup = None

         elif model == "gptj":
             lookup = None
7 changes: 7 additions & 0 deletions mlc_llm/utils.py
@@ -36,15 +36,21 @@ def argparse_postproc_common(args: argparse.Namespace) -> None:
         "moss-moon-003-sft": "gptj",
         "moss-moon-003-base": "gptj",
         "rwkv-": "rwkv",
+        "rwkv_world": "rwkv_world",
         "minigpt": "minigpt",
     }
     try:
         with open(os.path.join(args.model_path, "config.json"), encoding="utf-8") as i_f:
             config = json.load(i_f)
             args.model_category = config["model_type"]
+            model_path_lower = args.model_path.lower()
+            if "rwkv" in model_path_lower and "world" in model_path_lower:
+                args.model_category = "rwkv_world"
     except Exception:
         args.model_category = ""
     model = args.model.lower()
+    if "rwkv" in model and "world" in model:
+        model = "rwkv_world"
     for prefix, override_category in model_category_override.items():
         if model.startswith(prefix):
             args.model_category = override_category
@@ -67,6 +73,7 @@ def argparse_postproc_common(args: argparse.Namespace) -> None:
         "gpt-j-": "LM",
         "open_llama": "LM",
         "rwkv-": "rwkv",
+        "rwkv_world": "rwkv_world",
         "gorilla-": "gorilla",
         "guanaco": "guanaco",
         "wizardlm-7b": "wizardlm_7b",  # first get rid of 7b
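Putting the two utils.py hunks together: an RWKV World checkpoint is recognized as rwkv_world even though its config.json reports model_type "rwkv". A small self-contained sketch of that normalization follows; the model id and the categorize helper are hypothetical names for this illustration.

def categorize(model_id: str, config_model_type: str) -> str:
    # Hypothetical condensation of argparse_postproc_common's category logic.
    category = config_model_type
    model = model_id.lower()
    if "rwkv" in model and "world" in model:   # normalize World checkpoints
        model = "rwkv_world"
    overrides = {"rwkv-": "rwkv", "rwkv_world": "rwkv_world"}
    for prefix, override in overrides.items():
        if model.startswith(prefix):
            category = override
    return category

print(categorize("RWKV-4-World-7B", "rwkv"))  # -> "rwkv_world"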