
Commit c409ca0

Enable vulkan for RWKV and some misc updates (octoml#264)
This PR enables Vulkan for RWKV and drops Vulkan from local target detection, since detection can cause cross-platform issues. It also introduces a max_gen_len parameter for chat.
1 parent 60e2176 commit c409ca0

6 files changed, +49 -38 lines changed

build.py

Lines changed: 1 addition & 0 deletions
@@ -291,6 +291,7 @@ def dump_default_mlc_chat_config(args):
     config["repetition_penalty"] = 1.0
     config["top_p"] = 0.95
     config["mean_gen_len"] = 128
+    config["max_gen_len"] = 512
     config["shift_fill_factor"] = 0.3
     config["tokenizer_files"] = utils.get_tokenizer_files(params_path)
 
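For reference, a sketch of the default chat config fields this function now writes, restricted to the keys visible in this hunk (the real file also carries model- and tokenizer-specific entries):

default_config = {
    "repetition_penalty": 1.0,
    "top_p": 0.95,
    "mean_gen_len": 128,
    "max_gen_len": 512,  # new: hard cap on tokens generated per reply
    "shift_fill_factor": 0.3,
}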

cpp/conversation.cc

Lines changed: 1 addition & 1 deletion
@@ -148,7 +148,7 @@ picojson::value Conversation::SerializeToJSON() const {
   return picojson::value(config);
 }
 
-std::string Conversation::SerializeToJSONStr() const { return SerializeToJSON().serialize(true); }
+std::string Conversation::GetConfigJSON() const { return SerializeToJSON().serialize(true); }
 
 }  // namespace llm
 }  // namespace mlc

cpp/conversation.h

Lines changed: 1 addition & 1 deletion
@@ -135,7 +135,7 @@ class Conversation {
    * \brief Serialize the Conversation to JSON String.
    * \return A string storing the serialized conversation in JSON format.
    */
-  std::string SerializeToJSONStr() const;
+  std::string GetConfigJSON() const;
 
   /*!
    * \brief Get the entire prompt array

cpp/llm_chat.cc

Lines changed: 25 additions & 16 deletions
@@ -179,6 +179,12 @@ class LLMChat {
     } else {
       CHECK(partial_update) << "Key \"mean_gen_len\" not found.";
     }
+    // NOTE: for backward compatibility,
+    // "max_gen_len" is optional.
+    if (config.count("max_gen_len")) {
+      CHECK(config["max_gen_len"].is<int64_t>());
+      this->max_gen_len_ = config["max_gen_len"].get<int64_t>();
+    }
     if (config.count("shift_fill_factor")) {
       CHECK(config["shift_fill_factor"].is<double>());
       this->shift_fill_factor_ = config["shift_fill_factor"].get<double>();
@@ -219,18 +225,8 @@ class LLMChat {
     LoadJSONOverride(config_json, partial_update);
   }
 
-  picojson::value SerializeToJSON() const {
-    picojson::object config;
-    config["temperature"] = picojson::value(this->temperature_);
-    config["repetition_penalty"] = picojson::value(this->repetition_penalty_);
-    config["top_p"] = picojson::value(this->top_p_);
-    config["mean_gen_len"] = picojson::value(this->mean_gen_len_);
-    config["shift_fill_factor"] = picojson::value(this->shift_fill_factor_);
-    config["conv_config"] = this->conversation_.SerializeToJSON();
-    return picojson::value(config);
-  }
 
-  std::string SerializeToJSONStr() const { return SerializeToJSON().serialize(true); }
+  std::string GetConfigJSON() const { return SerializeConfigToJSONValue().serialize(true); }
 
   /*!
    * \brief Reload model, tokenizers and configurations from the specified model path.
@@ -595,6 +591,17 @@ class LLMChat {
   }
 
  private:
+  picojson::value SerializeConfigToJSONValue() const {
+    picojson::object config;
+    config["temperature"] = picojson::value(this->temperature_);
+    config["repetition_penalty"] = picojson::value(this->repetition_penalty_);
+    config["top_p"] = picojson::value(this->top_p_);
+    config["mean_gen_len"] = picojson::value(this->mean_gen_len_);
+    config["max_gen_len"] = picojson::value(this->max_gen_len_);
+    config["shift_fill_factor"] = picojson::value(this->shift_fill_factor_);
+    config["conv_config"] = this->conversation_.SerializeToJSON();
+    return picojson::value(config);
+  }
   /*!
    * \brief Sample output token from logits on device
    */
@@ -662,8 +669,10 @@
         }
       }
     }
-    // TODO(mlc-team): add another per convo seq len trigger
-    if (total_seq_len_ >= max_window_size_) {
+
+    if (static_cast<int64_t>(output_ids_.size()) >= max_gen_len_) {
+      stop_triggered_ = true;
+    } else if (total_seq_len_ >= max_window_size_) {
       stop_triggered_ = true;
     }
     if (stop_triggered_) {
@@ -783,7 +792,7 @@
   // total sequence len,
   int64_t total_seq_len_{0};
   // max window size, mean generation length
-  int64_t max_window_size_{768}, mean_gen_len_{128};
+  int64_t max_window_size_{768}, mean_gen_len_{128}, max_gen_len_{512};
   // shift window fill factor
   double shift_fill_factor_{0.3};
   // temperature
@@ -927,9 +936,9 @@ class LLMChatModule : public ModuleNode {
   } else if (name == "reset_runtime_stats") {
     return PackedFunc(
         [this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { GetChat()->ResetRuntimeStats(); });
-  } else if (name == "serialize_config") {
+  } else if (name == "get_config_json") {
     return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
-      *rv = GetChat()->SerializeToJSONStr();
+      *rv = GetChat()->GetConfigJSON();
     });
   } else {
     return PackedFunc(nullptr);
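The behavioral change in this file is the extra stop trigger: generation now halts once the number of generated tokens reaches max_gen_len, in addition to the existing context-window check. A minimal Python sketch of that rule (the function and parameter names are illustrative, not part of the codebase):

def should_stop(num_output_tokens: int, total_seq_len: int,
                max_gen_len: int = 512, max_window_size: int = 768) -> bool:
    # Stop once the reply hits the per-generation cap (the new check) ...
    if num_output_tokens >= max_gen_len:
        return True
    # ... or once the sequence exhausts the model's context window (the old check).
    return total_seq_len >= max_window_size

The rename from serialize_config to get_config_json also means host code must ask for the new function name. A hedged sketch of a Python caller, assuming chat_mod is a tvm.runtime.Module backed by LLMChatModule (the handle name is an assumption, not from this diff):

import json

get_config_json = chat_mod["get_config_json"]  # formerly registered as "serialize_config"
config = json.loads(get_config_json())
print(config["max_gen_len"])  # 512 unless overridden in mlc-chat-config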

mlc_llm/utils.py

Lines changed: 20 additions & 19 deletions
@@ -40,8 +40,8 @@ class Quantization:
     "q0f16": Quantization(
         name="q0f16", mode="no", sym=False, storage_nbit=-1, model_dtype="float16"
     ),
-    "q8f16": Quantization(
-        name="q8f16", mode="uint8", sym=False, storage_nbit=-1, model_dtype="float16"
+    "q8f16_0": Quantization(
+        name="q8f16_0", mode="uint8", sym=False, storage_nbit=-1, model_dtype="float16"
     ),
 }
@@ -306,6 +306,7 @@ def _detect_local_vulkan():
             "thread_warp_size": dev.warp_size,
             "supports_float16": 1,
             "supports_int16": 1,
+            "supports_int8": 1,
             "supports_16bit_buffer": 1,
         }
     )
@@ -424,23 +425,23 @@ def compile_metal(src, target):
         args.target = target
         args.target_kind = "iphone"
     elif args.target == "vulkan":
-        target = _detect_local_vulkan()
-        if target is None:
-            print("Cannot detect local Vulkan GPU target! Falling back...")
-            target = tvm.target.Target(
-                tvm.target.Target(
-                    {
-                        "kind": "vulkan",
-                        "max_threads_per_block": 256,
-                        "max_shared_memory_per_block": 32768,
-                        "thread_warp_size": 1,
-                        "supports_float16": 1,
-                        "supports_int16": 1,
-                        "supports_16bit_buffer": 1,
-                    }
-                ),
-                host="llvm",
-            )
+        target = tvm.target.Target(
+            tvm.target.Target(
+                {
+                    "kind": "vulkan",
+                    "max_threads_per_block": 256,
+                    "max_shared_memory_per_block": 32768,
+                    "thread_warp_size": 1,
+                    "supports_float16": 1,
+                    "supports_int16": 1,
+                    "supports_int8": 1,
+                    "supports_8bit_buffer": 1,
+                    "supports_16bit_buffer": 1,
+                    "supports_storage_buffer_storage_class": 1,
+                }
+            ),
+            host="llvm",
+        )
         args.target = target
         args.target_kind = args.target.kind.default_keys[0]
     elif args.target == "webgpu":
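With local detection removed, the Vulkan target above is a fixed, conservative capability set, so artifacts built on one machine remain portable to others. If a particular device is suspected of not meeting those assumptions, one way to sanity-check it (not part of this commit) is to query the device through TVM:

import tvm

dev = tvm.vulkan(0)
if dev.exist:
    # Compare the device's reported limits against the hard-coded target above.
    print("max_threads_per_block:", dev.max_threads_per_block)
    print("warp_size:", dev.warp_size)
else:
    print("No local Vulkan device; the fixed target still works for cross-compilation.")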

tests/cpp/conv_unittest.cc

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 
 void _TestConversationJSONRoundTrip(std::string templ_name) {
   mlc::llm::Conversation conv = mlc::llm::Conversation::FromTemplate(templ_name);
-  std::string conv_json = conv.SerializeToJSONStr();
+  std::string conv_json = conv.GetConfigJSON();
   mlc::llm::Conversation conv_new;
   conv_new.LoadJSONOverride(conv_json, false);
   ASSERT_EQ(conv, conv_new);
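For intuition, the property this C++ test asserts, sketched in Python with the json module standing in for the Conversation serializer (the payload below is illustrative):

import json

original = {"temperature": 0.7, "top_p": 0.95, "conv_config": {"name": "vicuna_v1.1"}}
restored = json.loads(json.dumps(original))  # serialize, then reload
assert restored == original  # mirrors ASSERT_EQ(conv, conv_new)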
