Skip to content

Commit 9fa817f

Browse files
committed
add mmproj file support
1 parent 887055e commit 9fa817f

File tree

5 files changed

+161
-41
lines changed

5 files changed

+161
-41
lines changed

examples/cli/main.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ struct SDParams {
6060
std::string clip_vision_path;
6161
std::string t5xxl_path;
6262
std::string qwen2vl_path;
63+
std::string qwen2vl_vision_path;
6364
std::string diffusion_model_path;
6465
std::string high_noise_diffusion_model_path;
6566
std::string vae_path;
@@ -146,6 +147,7 @@ void print_params(SDParams params) {
146147
printf(" clip_vision_path: %s\n", params.clip_vision_path.c_str());
147148
printf(" t5xxl_path: %s\n", params.t5xxl_path.c_str());
148149
printf(" qwen2vl_path: %s\n", params.qwen2vl_path.c_str());
150+
printf(" qwen2vl_vision_path: %s\n", params.qwen2vl_vision_path.c_str());
149151
printf(" diffusion_model_path: %s\n", params.diffusion_model_path.c_str());
150152
printf(" high_noise_diffusion_model_path: %s\n", params.high_noise_diffusion_model_path.c_str());
151153
printf(" vae_path: %s\n", params.vae_path.c_str());
@@ -218,6 +220,7 @@ void print_usage(int argc, const char* argv[]) {
218220
printf(" --clip_vision path to the clip-vision encoder\n");
219221
printf(" --t5xxl path to the t5xxl text encoder\n");
220222
printf(" --qwen2vl path to the qwen2vl text encoder\n");
223+
printf(" --qwen2vl_vision path to the qwen2vl vit\n");
221224
printf(" --vae [VAE] path to vae\n");
222225
printf(" --taesd [TAESD_PATH] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)\n");
223226
printf(" --control-net [CONTROL_PATH] path to control net model\n");
@@ -488,6 +491,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
488491
{"", "--clip_vision", "", &params.clip_vision_path},
489492
{"", "--t5xxl", "", &params.t5xxl_path},
490493
{"", "--qwen2vl", "", &params.qwen2vl_path},
494+
{"", "--qwen2vl_vision", "", &params.qwen2vl_vision_path},
491495
{"", "--diffusion-model", "", &params.diffusion_model_path},
492496
{"", "--high-noise-diffusion-model", "", &params.high_noise_diffusion_model_path},
493497
{"", "--vae", "", &params.vae_path},
@@ -947,7 +951,7 @@ std::string get_image_params(SDParams params, int64_t seed) {
947951
parameter_string += " " + std::string(sd_schedule_name(params.sample_params.scheduler));
948952
}
949953
parameter_string += ", ";
950-
for (const auto& te : {params.clip_l_path, params.clip_g_path, params.t5xxl_path, params.qwen2vl_path}) {
954+
for (const auto& te : {params.clip_l_path, params.clip_g_path, params.t5xxl_path, params.qwen2vl_path, params.qwen2vl_vision_path}) {
951955
if (!te.empty()) {
952956
parameter_string += "TE: " + sd_basename(te) + ", ";
953957
}
@@ -1322,6 +1326,7 @@ int main(int argc, const char* argv[]) {
13221326
params.clip_vision_path.c_str(),
13231327
params.t5xxl_path.c_str(),
13241328
params.qwen2vl_path.c_str(),
1329+
params.qwen2vl_vision_path.c_str(),
13251330
params.diffusion_model_path.c_str(),
13261331
params.high_noise_diffusion_model_path.c_str(),
13271332
params.vae_path.c_str(),

model.cpp

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,24 @@ std::unordered_map<std::string, std::string> qwenvl_name_map{
211211
{"output_norm.", "model.norm."},
212212
};
213213

214+
// Substring rewrite table for the qwen2vl vision ("qwen2vl.visual") tensor names.
// convert_cond_model_name() iterates this map for names containing
// "qwen2vl.visual" and, for every entry, replaces the first occurrence of the
// key (left) with the mapped value (right).
// Presumably the keys are mmproj/GGUF-style name fragments and the values are
// the names the vision model code expects - verify against the loader.
// NOTE(review): std::unordered_map iteration order is unspecified, and
// "v.patch_embd.weight" is a prefix of "v.patch_embd.weight.1". The
// "patch_embed.proj.0.weight.1" entry looks like a fix-up for the case where
// the shorter key is applied first, but it only helps if it is visited after
// that key - confirm the ".1" tensor is renamed correctly on every standard
// library, or switch to an ordered container.
std::unordered_map<std::string, std::string> qwenvl_vision_name_map{
215+
{"mm.", "merger.mlp."},
216+
{"v.post_ln.", "merger.ln_q."},
217+
{"v.patch_embd.weight", "patch_embed.proj.0.weight"},
218+
{"patch_embed.proj.0.weight.1", "patch_embed.proj.1.weight"},
219+
{"v.patch_embd.weight.1", "patch_embed.proj.1.weight"},
220+
{"v.blk.", "blocks."},
221+
{"attn_q.", "attn.q_proj."},
222+
{"attn_k.", "attn.k_proj."},
223+
{"attn_v.", "attn.v_proj."},
224+
{"attn_out.", "attn.proj."},
225+
{"ffn_down.", "mlp.down_proj."},
226+
{"ffn_gate.", "mlp.gate_proj."},
227+
{"ffn_up.", "mlp.up_proj."},
228+
{"ln1.", "norm1."},
229+
{"ln2.", "norm2."},
230+
};
231+
214232
std::string convert_cond_model_name(const std::string& name) {
215233
std::string new_name = name;
216234
std::string prefix;
@@ -269,10 +287,19 @@ std::string convert_cond_model_name(const std::string& name) {
269287
new_name.replace(pos, 11, "layer.0.SelfAttention.relative_attention_bias.");
270288
}
271289
} else if (contains(name, "qwen2vl")) {
272-
for (auto kv : qwenvl_name_map) {
273-
size_t pos = new_name.find(kv.first);
274-
if (pos != std::string::npos) {
275-
new_name.replace(pos, kv.first.size(), kv.second);
290+
if (contains(name, "qwen2vl.visual")) {
291+
for (auto kv : qwenvl_vision_name_map) {
292+
size_t pos = new_name.find(kv.first);
293+
if (pos != std::string::npos) {
294+
new_name.replace(pos, kv.first.size(), kv.second);
295+
}
296+
}
297+
} else {
298+
for (auto kv : qwenvl_name_map) {
299+
size_t pos = new_name.find(kv.first);
300+
if (pos != std::string::npos) {
301+
new_name.replace(pos, kv.first.size(), kv.second);
302+
}
276303
}
277304
}
278305
} else if (name == "text_encoders.t5xxl.transformer.token_embd.weight") {

0 commit comments

Comments
 (0)