@@ -41,13 +41,15 @@ const char* modes_str[] = {
4141 " img_gen" ,
4242 " vid_gen" ,
4343 " convert" ,
44+ " upscale" ,
4445};
45- #define SD_ALL_MODES_STR " img_gen, vid_gen, convert"
46+ #define SD_ALL_MODES_STR " img_gen, vid_gen, convert, upscale "
4647
4748enum SDMode {
4849 IMG_GEN,
4950 VID_GEN,
5051 CONVERT,
52+ UPSCALE,
5153 MODE_COUNT
5254};
5355
@@ -204,7 +206,7 @@ void print_usage(int argc, const char* argv[]) {
204206 printf (" \n " );
205207 printf (" arguments:\n " );
206208 printf (" -h, --help show this help message and exit\n " );
207- printf (" -M, --mode [MODE] run mode, one of: [img_gen, vid_gen, convert], default: img_gen\n " );
209+ printf (" -M, --mode [MODE] run mode, one of: [img_gen, vid_gen, upscale, convert], default: img_gen\n " );
208210 printf (" -t, --threads N number of threads to use during computation (default: -1)\n " );
209211 printf (" If threads <= 0, then threads will be set to the number of CPU physical cores\n " );
210212 printf (" --offload-to-cpu place the weights in RAM to save VRAM, and automatically load them into VRAM when needed\n " );
@@ -219,7 +221,7 @@ void print_usage(int argc, const char* argv[]) {
219221 printf (" --taesd [TAESD_PATH] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)\n " );
220222 printf (" --control-net [CONTROL_PATH] path to control net model\n " );
221223 printf (" --embd-dir [EMBEDDING_PATH] path to embeddings\n " );
222- printf (" --upscale-model [ESRGAN_PATH] path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now\n " );
224+ printf (" --upscale-model [ESRGAN_PATH] path to esrgan model. For img_gen mode, upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now\n " );
223225 printf (" --upscale-repeats Run the ESRGAN upscaler this many times (default 1)\n " );
224226 printf (" --type [TYPE] weight type (examples: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, q4_K)\n " );
225227 printf (" If not specified, the default is the type of the weight file\n " );
@@ -817,13 +819,13 @@ void parse_args(int argc, const char** argv, SDParams& params) {
817819 params.n_threads = get_num_physical_cores ();
818820 }
819821
820- if (params.mode != CONVERT && params.mode != VID_GEN && params.prompt .length () == 0 ) {
822+ if (( params.mode == IMG_GEN || params.mode == VID_GEN) && params.prompt .length () == 0 ) {
821823 fprintf (stderr, " error: the following arguments are required: prompt\n " );
822824 print_usage (argc, argv);
823825 exit (1 );
824826 }
825827
826- if (params.model_path .length () == 0 && params.diffusion_model_path .length () == 0 ) {
828+ if (params.mode != UPSCALE && params. model_path .length () == 0 && params.diffusion_model_path .length () == 0 ) {
827829 fprintf (stderr, " error: the following arguments are required: model_path/diffusion_model\n " );
828830 print_usage (argc, argv);
829831 exit (1 );
@@ -883,6 +885,17 @@ void parse_args(int argc, const char** argv, SDParams& params) {
883885 exit (1 );
884886 }
885887
888+ if (params.mode == UPSCALE) {
889+ if (params.esrgan_path .length () == 0 ) {
890+ fprintf (stderr, " error: upscale mode needs an upscaler model (--upscale-model)\n " );
891+ exit (1 );
892+ }
893+ if (params.init_image_path .length () == 0 ) {
894+ fprintf (stderr, " error: upscale mode needs an init image (--init-img)\n " );
895+ exit (1 );
896+ }
897+ }
898+
886899 if (params.seed < 0 ) {
887900 srand ((int )time (NULL ));
888901 params.seed = rand ();
@@ -1352,76 +1365,92 @@ int main(int argc, const char* argv[]) {
13521365 params.flow_shift ,
13531366 };
13541367
1355- sd_ctx_t * sd_ctx = new_sd_ctx (&sd_ctx_params);
1368+ sd_image_t * results = nullptr ;
1369+ int num_results = 0 ;
13561370
1357- if (sd_ctx == NULL ) {
1358- printf (" new_sd_ctx_t failed\n " );
1359- release_all_resources ();
1360- return 1 ;
1361- }
1371+ if (params.mode == UPSCALE) {
1372+ num_results = 1 ;
1373+ results = (sd_image_t *)calloc (num_results, sizeof (sd_image_t ));
1374+ if (results == NULL ) {
1375+ printf (" failed to allocate results array\n " );
1376+ release_all_resources ();
1377+ return 1 ;
1378+ }
13621379
1363- if (params.sample_params .sample_method == SAMPLE_METHOD_DEFAULT) {
1364- params.sample_params .sample_method = sd_get_default_sample_method (sd_ctx);
1365- }
1380+ results[0 ] = init_image;
1381+ init_image.data = NULL ;
1382+ } else {
1383+ sd_ctx_t * sd_ctx = new_sd_ctx (&sd_ctx_params);
13661384
1367- sd_image_t * results;
1368- int num_results = 1 ;
1369- if (params.mode == IMG_GEN) {
1370- sd_img_gen_params_t img_gen_params = {
1371- params.prompt .c_str (),
1372- params.negative_prompt .c_str (),
1373- params.clip_skip ,
1374- init_image,
1375- ref_images.data (),
1376- (int )ref_images.size (),
1377- params.increase_ref_index ,
1378- mask_image,
1379- params.width ,
1380- params.height ,
1381- params.sample_params ,
1382- params.strength ,
1383- params.seed ,
1384- params.batch_count ,
1385- control_image,
1386- params.control_strength ,
1387- {
1388- pmid_images.data (),
1389- (int )pmid_images.size (),
1390- params.pm_id_embed_path .c_str (),
1391- params.pm_style_strength ,
1392- }, // pm_params
1393- params.vae_tiling_params ,
1394- };
1395-
1396- results = generate_image (sd_ctx, &img_gen_params);
1397- num_results = params.batch_count ;
1398- } else if (params.mode == VID_GEN) {
1399- sd_vid_gen_params_t vid_gen_params = {
1400- params.prompt .c_str (),
1401- params.negative_prompt .c_str (),
1402- params.clip_skip ,
1403- init_image,
1404- end_image,
1405- control_frames.data (),
1406- (int )control_frames.size (),
1407- params.width ,
1408- params.height ,
1409- params.sample_params ,
1410- params.high_noise_sample_params ,
1411- params.moe_boundary ,
1412- params.strength ,
1413- params.seed ,
1414- params.video_frames ,
1415- params.vace_strength ,
1416- };
1417-
1418- results = generate_video (sd_ctx, &vid_gen_params, &num_results);
1419- }
1385+ if (sd_ctx == NULL ) {
1386+ printf (" new_sd_ctx_t failed\n " );
1387+ release_all_resources ();
1388+ return 1 ;
1389+ }
1390+
1391+ if (params.sample_params .sample_method == SAMPLE_METHOD_DEFAULT) {
1392+ params.sample_params .sample_method = sd_get_default_sample_method (sd_ctx);
1393+ }
1394+
1395+ if (params.mode == IMG_GEN) {
1396+ sd_img_gen_params_t img_gen_params = {
1397+ params.prompt .c_str (),
1398+ params.negative_prompt .c_str (),
1399+ params.clip_skip ,
1400+ init_image,
1401+ ref_images.data (),
1402+ (int )ref_images.size (),
1403+ params.increase_ref_index ,
1404+ mask_image,
1405+ params.width ,
1406+ params.height ,
1407+ params.sample_params ,
1408+ params.strength ,
1409+ params.seed ,
1410+ params.batch_count ,
1411+ control_image,
1412+ params.control_strength ,
1413+ {
1414+ pmid_images.data (),
1415+ (int )pmid_images.size (),
1416+ params.pm_id_embed_path .c_str (),
1417+ params.pm_style_strength ,
1418+ }, // pm_params
1419+ params.vae_tiling_params ,
1420+ };
1421+
1422+ results = generate_image (sd_ctx, &img_gen_params);
1423+ num_results = params.batch_count ;
1424+ } else if (params.mode == VID_GEN) {
1425+ sd_vid_gen_params_t vid_gen_params = {
1426+ params.prompt .c_str (),
1427+ params.negative_prompt .c_str (),
1428+ params.clip_skip ,
1429+ init_image,
1430+ end_image,
1431+ control_frames.data (),
1432+ (int )control_frames.size (),
1433+ params.width ,
1434+ params.height ,
1435+ params.sample_params ,
1436+ params.high_noise_sample_params ,
1437+ params.moe_boundary ,
1438+ params.strength ,
1439+ params.seed ,
1440+ params.video_frames ,
1441+ params.vace_strength ,
1442+ };
1443+
1444+ results = generate_video (sd_ctx, &vid_gen_params, &num_results);
1445+ }
1446+
1447+ if (results == NULL ) {
1448+ printf (" generate failed\n " );
1449+ free_sd_ctx (sd_ctx);
1450+ return 1 ;
1451+ }
14201452
1421- if (results == NULL ) {
1422- printf (" generate failed\n " );
14231453 free_sd_ctx (sd_ctx);
1424- return 1 ;
14251454 }
14261455
14271456 int upscale_factor = 4 ; // unused for RealESRGAN_x4plus_anime_6B.pth
@@ -1434,7 +1463,7 @@ int main(int argc, const char* argv[]) {
14341463 if (upscaler_ctx == NULL ) {
14351464 printf (" new_upscaler_ctx failed\n " );
14361465 } else {
1437- for (int i = 0 ; i < params. batch_count ; i++) {
1466+ for (int i = 0 ; i < num_results ; i++) {
14381467 if (results[i].data == NULL ) {
14391468 continue ;
14401469 }
@@ -1520,7 +1549,6 @@ int main(int argc, const char* argv[]) {
15201549 results[i].data = NULL ;
15211550 }
15221551 free (results);
1523- free_sd_ctx (sd_ctx);
15241552
15251553 release_all_resources ();
15261554
0 commit comments