diff --git a/scripts/gen_vid_diffusion.py b/scripts/gen_vid_diffusion.py index 5c25b0757..46967b123 100644 --- a/scripts/gen_vid_diffusion.py +++ b/scripts/gen_vid_diffusion.py @@ -275,6 +275,7 @@ def generate( bbox_select_list = [] img_tensor_list = [] out_img_list = [] + sequence_count = 0 for img_path, bbox_path in zip(limited_paths_img, limited_paths_bbox): img_in = os.path.join(os.path.dirname(os.path.dirname(paths_in_file)), img_path) bbox_in = os.path.join( @@ -599,7 +600,7 @@ def generate( cond_image = fill_img_with_sketch( img_tensor.unsqueeze(0), mask.unsqueeze(0) ) - elif opt.alg_diffusion_cond_image_creation == "canny": + elif opt.alg_diffusion_cond_image_creation == "computed_sketch": # "canny": clamp = torch.clamp(mask, 0, 1) if cond_in: # mask the background to avoid canny edges around cond image @@ -613,6 +614,7 @@ def generate( high_threshold=alg_diffusion_sketch_canny_thresholds[1], low_threshold_random=-1, high_threshold_random=-1, + select_canny=[1] + [0] * (opt.data_temporal_number_frames - 1), ) if cond_in: # restore background @@ -683,8 +685,22 @@ def generate( else: ref_tensor = None - cond_image_list.append(cond_image) - y_t_list.append(y_t) + if opt.alg_diffusion_cond_image_creation == "computed_sketch": + if sequence_count == 0: + cond_image_list.append(cond_image) + y_t_list.append(y_t) + else: + cond_image_list.append(y_t_list[0]) + y_t_list.append(y_t_list[0]) + if opt.alg_diffusion_cond_image_creation == "y_t": + if sequence_count == 0: + cond_image_list.append(cond_image) + y_t_list.append(y_t) + else: + cond_image_list.append(cond_image) + y_t_list.append(y_t_list[0]) + + sequence_count = sequence_count + 1 y0_tensor_list.append(y0_tensor) mask_list.append(mask) bbox_select_list.append(bbox_select)