Skip to content

Commit 5b261b9

Browse files
wbrunaleejet
andauthored
feat: add a stand-alone upscale mode (#865)
* feat: add a stand-alone upscale mode * fix prompt option check * format code * update README.md --------- Co-authored-by: leejet <[email protected]>
1 parent e70d020 commit 5b261b9

File tree

2 files changed

+102
-74
lines changed

2 files changed

+102
-74
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ usage: ./bin/sd [arguments]
285285
286286
arguments:
287287
-h, --help show this help message and exit
288-
-M, --mode [MODE] run mode, one of: [img_gen, vid_gen, convert], default: img_gen
288+
-M, --mode [MODE] run mode, one of: [img_gen, vid_gen, upscale, convert], default: img_gen
289289
-t, --threads N number of threads to use during computation (default: -1)
290290
If threads <= 0, then threads will be set to the number of CPU physical cores
291291
--offload-to-cpu place the weights in RAM to save VRAM, and automatically load them into VRAM when needed
@@ -300,7 +300,7 @@ arguments:
300300
--taesd [TAESD_PATH] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)
301301
--control-net [CONTROL_PATH] path to control net model
302302
--embd-dir [EMBEDDING_PATH] path to embeddings
303-
--upscale-model [ESRGAN_PATH] path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now
303+
--upscale-model [ESRGAN_PATH] path to esrgan model. For img_gen mode, upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now
304304
--upscale-repeats Run the ESRGAN upscaler this many times (default 1)
305305
--type [TYPE] weight type (examples: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, q4_K)
306306
If not specified, the default is the type of the weight file

examples/cli/main.cpp

Lines changed: 100 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,15 @@ const char* modes_str[] = {
4141
"img_gen",
4242
"vid_gen",
4343
"convert",
44+
"upscale",
4445
};
45-
#define SD_ALL_MODES_STR "img_gen, vid_gen, convert"
46+
#define SD_ALL_MODES_STR "img_gen, vid_gen, convert, upscale"
4647

4748
enum SDMode {
4849
IMG_GEN,
4950
VID_GEN,
5051
CONVERT,
52+
UPSCALE,
5153
MODE_COUNT
5254
};
5355

@@ -204,7 +206,7 @@ void print_usage(int argc, const char* argv[]) {
204206
printf("\n");
205207
printf("arguments:\n");
206208
printf(" -h, --help show this help message and exit\n");
207-
printf(" -M, --mode [MODE] run mode, one of: [img_gen, vid_gen, convert], default: img_gen\n");
209+
printf(" -M, --mode [MODE] run mode, one of: [img_gen, vid_gen, upscale, convert], default: img_gen\n");
208210
printf(" -t, --threads N number of threads to use during computation (default: -1)\n");
209211
printf(" If threads <= 0, then threads will be set to the number of CPU physical cores\n");
210212
printf(" --offload-to-cpu place the weights in RAM to save VRAM, and automatically load them into VRAM when needed\n");
@@ -219,7 +221,7 @@ void print_usage(int argc, const char* argv[]) {
219221
printf(" --taesd [TAESD_PATH] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)\n");
220222
printf(" --control-net [CONTROL_PATH] path to control net model\n");
221223
printf(" --embd-dir [EMBEDDING_PATH] path to embeddings\n");
222-
printf(" --upscale-model [ESRGAN_PATH] path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now\n");
224+
printf(" --upscale-model [ESRGAN_PATH] path to esrgan model. For img_gen mode, upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now\n");
223225
printf(" --upscale-repeats Run the ESRGAN upscaler this many times (default 1)\n");
224226
printf(" --type [TYPE] weight type (examples: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, q4_K)\n");
225227
printf(" If not specified, the default is the type of the weight file\n");
@@ -817,13 +819,13 @@ void parse_args(int argc, const char** argv, SDParams& params) {
817819
params.n_threads = get_num_physical_cores();
818820
}
819821

820-
if (params.mode != CONVERT && params.mode != VID_GEN && params.prompt.length() == 0) {
822+
if ((params.mode == IMG_GEN || params.mode == VID_GEN) && params.prompt.length() == 0) {
821823
fprintf(stderr, "error: the following arguments are required: prompt\n");
822824
print_usage(argc, argv);
823825
exit(1);
824826
}
825827

826-
if (params.model_path.length() == 0 && params.diffusion_model_path.length() == 0) {
828+
if (params.mode != UPSCALE && params.model_path.length() == 0 && params.diffusion_model_path.length() == 0) {
827829
fprintf(stderr, "error: the following arguments are required: model_path/diffusion_model\n");
828830
print_usage(argc, argv);
829831
exit(1);
@@ -883,6 +885,17 @@ void parse_args(int argc, const char** argv, SDParams& params) {
883885
exit(1);
884886
}
885887

888+
if (params.mode == UPSCALE) {
889+
if (params.esrgan_path.length() == 0) {
890+
fprintf(stderr, "error: upscale mode needs an upscaler model (--upscale-model)\n");
891+
exit(1);
892+
}
893+
if (params.init_image_path.length() == 0) {
894+
fprintf(stderr, "error: upscale mode needs an init image (--init-img)\n");
895+
exit(1);
896+
}
897+
}
898+
886899
if (params.seed < 0) {
887900
srand((int)time(NULL));
888901
params.seed = rand();
@@ -1352,76 +1365,92 @@ int main(int argc, const char* argv[]) {
13521365
params.flow_shift,
13531366
};
13541367

1355-
sd_ctx_t* sd_ctx = new_sd_ctx(&sd_ctx_params);
1368+
sd_image_t* results = nullptr;
1369+
int num_results = 0;
13561370

1357-
if (sd_ctx == NULL) {
1358-
printf("new_sd_ctx_t failed\n");
1359-
release_all_resources();
1360-
return 1;
1361-
}
1371+
if (params.mode == UPSCALE) {
1372+
num_results = 1;
1373+
results = (sd_image_t*)calloc(num_results, sizeof(sd_image_t));
1374+
if (results == NULL) {
1375+
printf("failed to allocate results array\n");
1376+
release_all_resources();
1377+
return 1;
1378+
}
13621379

1363-
if (params.sample_params.sample_method == SAMPLE_METHOD_DEFAULT) {
1364-
params.sample_params.sample_method = sd_get_default_sample_method(sd_ctx);
1365-
}
1380+
results[0] = init_image;
1381+
init_image.data = NULL;
1382+
} else {
1383+
sd_ctx_t* sd_ctx = new_sd_ctx(&sd_ctx_params);
13661384

1367-
sd_image_t* results;
1368-
int num_results = 1;
1369-
if (params.mode == IMG_GEN) {
1370-
sd_img_gen_params_t img_gen_params = {
1371-
params.prompt.c_str(),
1372-
params.negative_prompt.c_str(),
1373-
params.clip_skip,
1374-
init_image,
1375-
ref_images.data(),
1376-
(int)ref_images.size(),
1377-
params.increase_ref_index,
1378-
mask_image,
1379-
params.width,
1380-
params.height,
1381-
params.sample_params,
1382-
params.strength,
1383-
params.seed,
1384-
params.batch_count,
1385-
control_image,
1386-
params.control_strength,
1387-
{
1388-
pmid_images.data(),
1389-
(int)pmid_images.size(),
1390-
params.pm_id_embed_path.c_str(),
1391-
params.pm_style_strength,
1392-
}, // pm_params
1393-
params.vae_tiling_params,
1394-
};
1395-
1396-
results = generate_image(sd_ctx, &img_gen_params);
1397-
num_results = params.batch_count;
1398-
} else if (params.mode == VID_GEN) {
1399-
sd_vid_gen_params_t vid_gen_params = {
1400-
params.prompt.c_str(),
1401-
params.negative_prompt.c_str(),
1402-
params.clip_skip,
1403-
init_image,
1404-
end_image,
1405-
control_frames.data(),
1406-
(int)control_frames.size(),
1407-
params.width,
1408-
params.height,
1409-
params.sample_params,
1410-
params.high_noise_sample_params,
1411-
params.moe_boundary,
1412-
params.strength,
1413-
params.seed,
1414-
params.video_frames,
1415-
params.vace_strength,
1416-
};
1417-
1418-
results = generate_video(sd_ctx, &vid_gen_params, &num_results);
1419-
}
1385+
if (sd_ctx == NULL) {
1386+
printf("new_sd_ctx_t failed\n");
1387+
release_all_resources();
1388+
return 1;
1389+
}
1390+
1391+
if (params.sample_params.sample_method == SAMPLE_METHOD_DEFAULT) {
1392+
params.sample_params.sample_method = sd_get_default_sample_method(sd_ctx);
1393+
}
1394+
1395+
if (params.mode == IMG_GEN) {
1396+
sd_img_gen_params_t img_gen_params = {
1397+
params.prompt.c_str(),
1398+
params.negative_prompt.c_str(),
1399+
params.clip_skip,
1400+
init_image,
1401+
ref_images.data(),
1402+
(int)ref_images.size(),
1403+
params.increase_ref_index,
1404+
mask_image,
1405+
params.width,
1406+
params.height,
1407+
params.sample_params,
1408+
params.strength,
1409+
params.seed,
1410+
params.batch_count,
1411+
control_image,
1412+
params.control_strength,
1413+
{
1414+
pmid_images.data(),
1415+
(int)pmid_images.size(),
1416+
params.pm_id_embed_path.c_str(),
1417+
params.pm_style_strength,
1418+
}, // pm_params
1419+
params.vae_tiling_params,
1420+
};
1421+
1422+
results = generate_image(sd_ctx, &img_gen_params);
1423+
num_results = params.batch_count;
1424+
} else if (params.mode == VID_GEN) {
1425+
sd_vid_gen_params_t vid_gen_params = {
1426+
params.prompt.c_str(),
1427+
params.negative_prompt.c_str(),
1428+
params.clip_skip,
1429+
init_image,
1430+
end_image,
1431+
control_frames.data(),
1432+
(int)control_frames.size(),
1433+
params.width,
1434+
params.height,
1435+
params.sample_params,
1436+
params.high_noise_sample_params,
1437+
params.moe_boundary,
1438+
params.strength,
1439+
params.seed,
1440+
params.video_frames,
1441+
params.vace_strength,
1442+
};
1443+
1444+
results = generate_video(sd_ctx, &vid_gen_params, &num_results);
1445+
}
1446+
1447+
if (results == NULL) {
1448+
printf("generate failed\n");
1449+
free_sd_ctx(sd_ctx);
1450+
return 1;
1451+
}
14201452

1421-
if (results == NULL) {
1422-
printf("generate failed\n");
14231453
free_sd_ctx(sd_ctx);
1424-
return 1;
14251454
}
14261455

14271456
int upscale_factor = 4; // unused for RealESRGAN_x4plus_anime_6B.pth
@@ -1434,7 +1463,7 @@ int main(int argc, const char* argv[]) {
14341463
if (upscaler_ctx == NULL) {
14351464
printf("new_upscaler_ctx failed\n");
14361465
} else {
1437-
for (int i = 0; i < params.batch_count; i++) {
1466+
for (int i = 0; i < num_results; i++) {
14381467
if (results[i].data == NULL) {
14391468
continue;
14401469
}
@@ -1520,7 +1549,6 @@ int main(int argc, const char* argv[]) {
15201549
results[i].data = NULL;
15211550
}
15221551
free(results);
1523-
free_sd_ctx(sd_ctx);
15241552

15251553
release_all_resources();
15261554

0 commit comments

Comments
 (0)