From aefdc384c8e973b83c7b9ac661d069face054208 Mon Sep 17 00:00:00 2001 From: Emmanuel Benazera Date: Mon, 1 Jul 2024 10:42:58 +0000 Subject: [PATCH] fix: diffusion inference for images > 8bits --- scripts/gen_single_image_diffusion.py | 130 ++++++++++++++++++-------- 1 file changed, 89 insertions(+), 41 deletions(-) diff --git a/scripts/gen_single_image_diffusion.py b/scripts/gen_single_image_diffusion.py index 50d8a0407..3fd9b6d26 100644 --- a/scripts/gen_single_image_diffusion.py +++ b/scripts/gen_single_image_diffusion.py @@ -15,10 +15,13 @@ import torch import torchvision.transforms as T from PIL import Image +import torchvision from torchvision import transforms from torchvision.utils import save_image from tqdm import tqdm +from PIL import Image + jg_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../") sys.path.append(jg_dir) from segment_anything import SamPredictor @@ -110,6 +113,9 @@ def load_model( if opt.model_type == "palette": model.set_new_sampling_method(sampling_method) + if opt.alg_diffusion_task == "pix2pix": + opt.alg_diffusion_cond_image_creation = "pix2pix" + model = model.to(device) return model, opt @@ -241,9 +247,14 @@ def generate( # Load image # reading image - img = cv2.imread(img_in) + if opt.data_image_bits > 8: + img = Image.open(img_in) # we use PIL + local_img_width, local_img_height = img.size + else: + img = cv2.imread(img_in) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + local_img_width, local_img_height = img.shape[:2] img_orig = img.copy() - img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # reading the mask mask = None @@ -382,11 +393,17 @@ def generate( img, mask = np.array(img), np.array(mask) if img_width > 0 and img_height > 0: - img = cv2.resize(img, (img_width, img_height)) - if mask is not None: - mask = cv2.resize(mask, (img_width, img_height)) - if ref is not None: - ref = cv2.resize(ref, (img_width, img_height)) + if img_height != local_img_height or img_width != local_img_width: + if opt.data_image_bits > 8: + print( + "Requested image size differs from training crop size, resizing is not supported for images with more than 8 bits per channel" + ) + exit(1) + img = cv2.resize(img, (img_width, img_height)) + if mask is not None: + mask = cv2.resize(mask, (img_width, img_height)) + if ref is not None: + ref = cv2.resize(ref, (img_width, img_height)) if logger: logger.info( @@ -456,11 +473,19 @@ def generate( # preprocessing to torch totensor = transforms.ToTensor() - tranlist = [ - totensor, - transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), - # resize, - ] + if opt.data_image_bits > 8: + tranlist = [totensor, torchvision.transforms.v2.ToDtype(torch.float32)] + bit_scaling = 2**opt.data_image_bits - 1 + tranlist += [transforms.Lambda(lambda img: img * (1 / float(bit_scaling)))] + tranlist += [ + transforms.Normalize((0.5,), (0.5,)) + ] # XXX: > 8bit, mono canal only for now + else: + tranlist = [ + totensor, + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + # resize, + ] tran = transforms.Compose(tranlist) img_tensor = tran(img).clone().detach() @@ -582,10 +607,15 @@ def generate( cond_image = transform_hr(cond_image).detach() elif opt.alg_diffusion_cond_image_creation == "pix2pix": # use same interpolation as get_transform - transform_hr = T.Resize( - (img_height, img_width), interpolation=T.InterpolationMode.BICUBIC - ) - cond_image = transform_hr(img_tensor.unsqueeze(0)).detach() + if (img_height > 0 and img_height != opt.data_crop_size) or ( + img_width > 0 and img_width != opt.data_crop_size + ): + transform_hr = T.Resize( + (img_height, img_width), interpolation=T.InterpolationMode.BICUBIC + ) + cond_image = transform_hr(img_tensor.unsqueeze(0)).detach() + else: + cond_image = img_tensor.unsqueeze(0).detach() if mask is None: cl_mask = None @@ -599,7 +629,9 @@ def generate( cl_mask, ) if mask == None: - img_tensor = None + y0_tensor = None + else: + y0_tensor = img_tensor if opt.model_type == "palette": if "class" in model.denoise_fn.conditioning: @@ -617,7 +649,7 @@ def generate( out_tensor, visu = model.restoration( y_cond=cond_image, y_t=y_t, - y_0=img_tensor, + y_0=y0_tensor, mask=mask, cls=cls_tensor, ref=ref_tensor, @@ -628,7 +660,10 @@ def generate( ) elif opt.model_type == "cm" or opt.model_type == "cm_gan": sampling_sigmas = (80.0, 24.4, 5.84, 0.9, 0.661) + out_tensor = model.restoration(y_t, cond_image, sampling_sigmas, mask) + + # XXX: !=8bit images are converted to 8bit RGB for now out_img = to_np( out_tensor ) # out_img = out_img.detach().data.cpu().float().numpy()[0] @@ -668,29 +703,42 @@ def generate( cond_img = to_np(cond_image) if write: - cv2.imwrite(os.path.join(dir_out, name + "_orig.png"), img_orig) - if cond_image is not None: - cv2.imwrite(os.path.join(dir_out, name + "_cond.png"), cond_img) - cv2.imwrite(os.path.join(dir_out, name + "_generated.png"), out_img_real_size) - cv2.imwrite(os.path.join(dir_out, name + "_y_t.png"), to_np(y_t)) - if mask is not None: - cv2.imwrite(os.path.join(dir_out, name + "_y_0.png"), to_np(img_tensor)) - cv2.imwrite(os.path.join(dir_out, name + "_generated_crop.png"), out_img) - cv2.imwrite(os.path.join(dir_out, name + "_mask.png"), to_np(mask)) - if ref is not None: - cv2.imwrite(os.path.join(dir_out, name + "_ref_orig.png"), ref_orig) - if cond_in: - # crop before cond image - orig_crop = img_orig[ - bbox_select[1] : bbox_select[3], bbox_select[0] : bbox_select[2] - ] - cv2.imwrite(os.path.join(dir_out, name + "_orig_crop.png"), orig_crop) - if bbox_in: - with open(os.path.join(dir_out, name + "_orig_bbox.json"), "w") as out: - out.write(json.dumps(bbox)) - if generated_bbox: - with open(os.path.join(dir_out, name + "_generated_bbox.json"), "w") as out: - out.write(json.dumps(generated_bbox)) + if opt.data_image_bits > 8: + img_np = to_np(img_tensor) # comes from PIL + cv2.imwrite(os.path.join(dir_out, name + "_orig.png"), img_np) + if cond_image is not None: + cv2.imwrite(os.path.join(dir_out, name + "_cond.png"), cond_img) + cv2.imwrite(os.path.join(dir_out, name + "_generated.png"), out_img_resized) + else: + cv2.imwrite(os.path.join(dir_out, name + "_orig.png"), img_orig) + if cond_image is not None: + cv2.imwrite(os.path.join(dir_out, name + "_cond.png"), cond_img) + cv2.imwrite( + os.path.join(dir_out, name + "_generated.png"), out_img_real_size + ) + cv2.imwrite(os.path.join(dir_out, name + "_y_t.png"), to_np(y_t)) + if mask is not None: + cv2.imwrite(os.path.join(dir_out, name + "_y_0.png"), to_np(img_tensor)) + cv2.imwrite( + os.path.join(dir_out, name + "_generated_crop.png"), out_img + ) + cv2.imwrite(os.path.join(dir_out, name + "_mask.png"), to_np(mask)) + if ref is not None: + cv2.imwrite(os.path.join(dir_out, name + "_ref_orig.png"), ref_orig) + if cond_in: + # crop before cond image + orig_crop = img_orig[ + bbox_select[1] : bbox_select[3], bbox_select[0] : bbox_select[2] + ] + cv2.imwrite(os.path.join(dir_out, name + "_orig_crop.png"), orig_crop) + if bbox_in: + with open(os.path.join(dir_out, name + "_orig_bbox.json"), "w") as out: + out.write(json.dumps(bbox)) + if generated_bbox: + with open( + os.path.join(dir_out, name + "_generated_bbox.json"), "w" + ) as out: + out.write(json.dumps(generated_bbox)) print("Successfully generated image ", name)