5555from diffusers .utils .torch_utils import is_compiled_module
5656
5757
58+ if is_wandb_available ():
59+ import wandb
60+
5861# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
5962check_min_version ("0.26.0.dev0" )
6063
6770TORCH_DTYPE_MAPPING = {"fp32" : torch .float32 , "fp16" : torch .float16 , "bf16" : torch .bfloat16 }
6871
6972
73+ def log_validation (
74+ pipeline ,
75+ args ,
76+ accelerator ,
77+ generator ,
78+ global_step ,
79+ is_final_validation = False ,
80+ ):
81+ logger .info (
82+ f"Running validation... \n Generating { args .num_validation_images } images with prompt:"
83+ f" { args .validation_prompt } ."
84+ )
85+
86+ pipeline = pipeline .to (accelerator .device )
87+ pipeline .set_progress_bar_config (disable = True )
88+
89+ if not is_final_validation :
90+ val_save_dir = os .path .join (args .output_dir , "validation_images" )
91+ if not os .path .exists (val_save_dir ):
92+ os .makedirs (val_save_dir )
93+
94+ original_image = (
95+ lambda image_url_or_path : load_image (image_url_or_path )
96+ if urlparse (image_url_or_path ).scheme
97+ else Image .open (image_url_or_path ).convert ("RGB" )
98+ )(args .val_image_url_or_path )
99+
100+ with torch .autocast (str (accelerator .device ).replace (":0" , "" ), enabled = accelerator .mixed_precision == "fp16" ):
101+ edited_images = []
102+ # Run inference
103+ for val_img_idx in range (args .num_validation_images ):
104+ a_val_img = pipeline (
105+ args .validation_prompt ,
106+ image = original_image ,
107+ num_inference_steps = 20 ,
108+ image_guidance_scale = 1.5 ,
109+ guidance_scale = 7 ,
110+ generator = generator ,
111+ ).images [0 ]
112+ edited_images .append (a_val_img )
113+ # Save validation images
114+ if not is_final_validation :
115+ a_val_img .save (os .path .join (val_save_dir , f"step_{ global_step } _val_img_{ val_img_idx } .png" ))
116+
117+ for tracker in accelerator .trackers :
118+ if tracker .name == "wandb" :
119+ wandb_table = wandb .Table (columns = WANDB_TABLE_COL_NAMES )
120+ for edited_image in edited_images :
121+ wandb_table .add_data (wandb .Image (original_image ), wandb .Image (edited_image ), args .validation_prompt )
122+ logger_name = "test" if is_final_validation else "validation"
123+ tracker .log ({logger_name : wandb_table })
124+
125+
70126def import_model_class_from_model_name_or_path (
71127 pretrained_model_name_or_path : str , revision : str , subfolder : str = "text_encoder"
72128):
@@ -447,11 +503,6 @@ def main():
447503
448504 generator = torch .Generator (device = accelerator .device ).manual_seed (args .seed )
449505
450- if args .report_to == "wandb" :
451- if not is_wandb_available ():
452- raise ImportError ("Make sure to install wandb if you want to use it for logging during training." )
453- import wandb
454-
455506 # Make one log on every process with the configuration for debugging.
456507 logging .basicConfig (
457508 format = "%(asctime)s - %(levelname)s - %(name)s - %(message)s" ,
@@ -1111,11 +1162,6 @@ def collate_fn(examples):
11111162 ### BEGIN: Perform validation every `validation_epochs` steps
11121163 if global_step % args .validation_steps == 0 :
11131164 if (args .val_image_url_or_path is not None ) and (args .validation_prompt is not None ):
1114- logger .info (
1115- f"Running validation... \n Generating { args .num_validation_images } images with prompt:"
1116- f" { args .validation_prompt } ."
1117- )
1118-
11191165 # create pipeline
11201166 if args .use_ema :
11211167 # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
@@ -1135,44 +1181,16 @@ def collate_fn(examples):
11351181 variant = args .variant ,
11361182 torch_dtype = weight_dtype ,
11371183 )
1138- pipeline = pipeline .to (accelerator .device )
1139- pipeline .set_progress_bar_config (disable = True )
1140-
1141- # run inference
1142- # Save validation images
1143- val_save_dir = os .path .join (args .output_dir , "validation_images" )
1144- if not os .path .exists (val_save_dir ):
1145- os .makedirs (val_save_dir )
1146-
1147- original_image = (
1148- lambda image_url_or_path : load_image (image_url_or_path )
1149- if urlparse (image_url_or_path ).scheme
1150- else Image .open (image_url_or_path ).convert ("RGB" )
1151- )(args .val_image_url_or_path )
1152- with torch .autocast (
1153- str (accelerator .device ).replace (":0" , "" ), enabled = accelerator .mixed_precision == "fp16"
1154- ):
1155- edited_images = []
1156- for val_img_idx in range (args .num_validation_images ):
1157- a_val_img = pipeline (
1158- args .validation_prompt ,
1159- image = original_image ,
1160- num_inference_steps = 20 ,
1161- image_guidance_scale = 1.5 ,
1162- guidance_scale = 7 ,
1163- generator = generator ,
1164- ).images [0 ]
1165- edited_images .append (a_val_img )
1166- a_val_img .save (os .path .join (val_save_dir , f"step_{ global_step } _val_img_{ val_img_idx } .png" ))
1167-
1168- for tracker in accelerator .trackers :
1169- if tracker .name == "wandb" :
1170- wandb_table = wandb .Table (columns = WANDB_TABLE_COL_NAMES )
1171- for edited_image in edited_images :
1172- wandb_table .add_data (
1173- wandb .Image (original_image ), wandb .Image (edited_image ), args .validation_prompt
1174- )
1175- tracker .log ({"validation" : wandb_table })
1184+
1185+ log_validation (
1186+ pipeline ,
1187+ args ,
1188+ accelerator ,
1189+ generator ,
1190+ global_step ,
1191+ is_final_validation = False ,
1192+ )
1193+
11761194 if args .use_ema :
11771195 # Switch back to the original UNet parameters.
11781196 ema_unet .restore (unet .parameters ())
@@ -1187,7 +1205,6 @@ def collate_fn(examples):
11871205 # Create the pipeline using the trained modules and save it.
11881206 accelerator .wait_for_everyone ()
11891207 if accelerator .is_main_process :
1190- unet = unwrap_model (unet )
11911208 if args .use_ema :
11921209 ema_unet .copy_to (unet .parameters ())
11931210
@@ -1198,10 +1215,11 @@ def collate_fn(examples):
11981215 tokenizer = tokenizer_1 ,
11991216 tokenizer_2 = tokenizer_2 ,
12001217 vae = vae ,
1201- unet = unet ,
1218+ unet = unwrap_model ( unet ) ,
12021219 revision = args .revision ,
12031220 variant = args .variant ,
12041221 )
1222+
12051223 pipeline .save_pretrained (args .output_dir )
12061224
12071225 if args .push_to_hub :
@@ -1212,30 +1230,15 @@ def collate_fn(examples):
12121230 ignore_patterns = ["step_*" , "epoch_*" ],
12131231 )
12141232
1215- if args .validation_prompt is not None :
1216- edited_images = []
1217- pipeline = pipeline .to (accelerator .device )
1218- with torch .autocast (str (accelerator .device ).replace (":0" , "" )):
1219- for _ in range (args .num_validation_images ):
1220- edited_images .append (
1221- pipeline (
1222- args .validation_prompt ,
1223- image = original_image ,
1224- num_inference_steps = 20 ,
1225- image_guidance_scale = 1.5 ,
1226- guidance_scale = 7 ,
1227- generator = generator ,
1228- ).images [0 ]
1229- )
1230-
1231- for tracker in accelerator .trackers :
1232- if tracker .name == "wandb" :
1233- wandb_table = wandb .Table (columns = WANDB_TABLE_COL_NAMES )
1234- for edited_image in edited_images :
1235- wandb_table .add_data (
1236- wandb .Image (original_image ), wandb .Image (edited_image ), args .validation_prompt
1237- )
1238- tracker .log ({"test" : wandb_table })
1233+ if (args .val_image_url_or_path is not None ) and (args .validation_prompt is not None ):
1234+ log_validation (
1235+ pipeline ,
1236+ args ,
1237+ accelerator ,
1238+ generator ,
1239+ global_step = None ,
1240+ is_final_validation = True ,
1241+ )
12391242
12401243 accelerator .end_training ()
12411244
0 commit comments