Skip to content

Commit

Permalink
fix: diffusion inference for images > 8bits
Browse files Browse the repository at this point in the history
  • Loading branch information
beniz committed Jul 3, 2024
1 parent ad0d99d commit aefdc38
Showing 1 changed file with 89 additions and 41 deletions.
130 changes: 89 additions & 41 deletions scripts/gen_single_image_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,13 @@
import torch
import torchvision.transforms as T
from PIL import Image
import torchvision
from torchvision import transforms
from torchvision.utils import save_image
from tqdm import tqdm

from PIL import Image

jg_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../")
sys.path.append(jg_dir)
from segment_anything import SamPredictor
Expand Down Expand Up @@ -110,6 +113,9 @@ def load_model(
if opt.model_type == "palette":
model.set_new_sampling_method(sampling_method)

if opt.alg_diffusion_task == "pix2pix":
opt.alg_diffusion_cond_image_creation = "pix2pix"

model = model.to(device)
return model, opt

Expand Down Expand Up @@ -241,9 +247,14 @@ def generate(
# Load image

# reading image
img = cv2.imread(img_in)
if opt.data_image_bits > 8:
img = Image.open(img_in) # we use PIL
local_img_width, local_img_height = img.size
else:
img = cv2.imread(img_in)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
local_img_width, local_img_height = img.shape[:2]
img_orig = img.copy()
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

# reading the mask
mask = None
Expand Down Expand Up @@ -382,11 +393,17 @@ def generate(
img, mask = np.array(img), np.array(mask)

if img_width > 0 and img_height > 0:
img = cv2.resize(img, (img_width, img_height))
if mask is not None:
mask = cv2.resize(mask, (img_width, img_height))
if ref is not None:
ref = cv2.resize(ref, (img_width, img_height))
if img_height != local_img_height or img_width != local_img_width:
if opt.data_image_bits > 8:
print(
"Requested image size differs from training crop size, resizing is not supported for images with more than 8 bits per channel"
)
exit(1)
img = cv2.resize(img, (img_width, img_height))
if mask is not None:
mask = cv2.resize(mask, (img_width, img_height))
if ref is not None:
ref = cv2.resize(ref, (img_width, img_height))

if logger:
logger.info(
Expand Down Expand Up @@ -456,11 +473,19 @@ def generate(

# preprocessing to torch
totensor = transforms.ToTensor()
tranlist = [
totensor,
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
# resize,
]
if opt.data_image_bits > 8:
tranlist = [totensor, torchvision.transforms.v2.ToDtype(torch.float32)]
bit_scaling = 2**opt.data_image_bits - 1
tranlist += [transforms.Lambda(lambda img: img * (1 / float(bit_scaling)))]
tranlist += [
transforms.Normalize((0.5,), (0.5,))
] # XXX: > 8bit, mono canal only for now
else:
tranlist = [
totensor,
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
# resize,
]

tran = transforms.Compose(tranlist)
img_tensor = tran(img).clone().detach()
Expand Down Expand Up @@ -582,10 +607,15 @@ def generate(
cond_image = transform_hr(cond_image).detach()
elif opt.alg_diffusion_cond_image_creation == "pix2pix":
# use same interpolation as get_transform
transform_hr = T.Resize(
(img_height, img_width), interpolation=T.InterpolationMode.BICUBIC
)
cond_image = transform_hr(img_tensor.unsqueeze(0)).detach()
if (img_height > 0 and img_height != opt.data_crop_size) or (
img_width > 0 and img_width != opt.data_crop_size
):
transform_hr = T.Resize(
(img_height, img_width), interpolation=T.InterpolationMode.BICUBIC
)
cond_image = transform_hr(img_tensor.unsqueeze(0)).detach()
else:
cond_image = img_tensor.unsqueeze(0).detach()

if mask is None:
cl_mask = None
Expand All @@ -599,7 +629,9 @@ def generate(
cl_mask,
)
if mask == None:
img_tensor = None
y0_tensor = None
else:
y0_tensor = img_tensor

if opt.model_type == "palette":
if "class" in model.denoise_fn.conditioning:
Expand All @@ -617,7 +649,7 @@ def generate(
out_tensor, visu = model.restoration(
y_cond=cond_image,
y_t=y_t,
y_0=img_tensor,
y_0=y0_tensor,
mask=mask,
cls=cls_tensor,
ref=ref_tensor,
Expand All @@ -628,7 +660,10 @@ def generate(
)
elif opt.model_type == "cm" or opt.model_type == "cm_gan":
sampling_sigmas = (80.0, 24.4, 5.84, 0.9, 0.661)

out_tensor = model.restoration(y_t, cond_image, sampling_sigmas, mask)

# XXX: !=8bit images are converted to 8bit RGB for now
out_img = to_np(
out_tensor
) # out_img = out_img.detach().data.cpu().float().numpy()[0]
Expand Down Expand Up @@ -668,29 +703,42 @@ def generate(
cond_img = to_np(cond_image)

if write:
cv2.imwrite(os.path.join(dir_out, name + "_orig.png"), img_orig)
if cond_image is not None:
cv2.imwrite(os.path.join(dir_out, name + "_cond.png"), cond_img)
cv2.imwrite(os.path.join(dir_out, name + "_generated.png"), out_img_real_size)
cv2.imwrite(os.path.join(dir_out, name + "_y_t.png"), to_np(y_t))
if mask is not None:
cv2.imwrite(os.path.join(dir_out, name + "_y_0.png"), to_np(img_tensor))
cv2.imwrite(os.path.join(dir_out, name + "_generated_crop.png"), out_img)
cv2.imwrite(os.path.join(dir_out, name + "_mask.png"), to_np(mask))
if ref is not None:
cv2.imwrite(os.path.join(dir_out, name + "_ref_orig.png"), ref_orig)
if cond_in:
# crop before cond image
orig_crop = img_orig[
bbox_select[1] : bbox_select[3], bbox_select[0] : bbox_select[2]
]
cv2.imwrite(os.path.join(dir_out, name + "_orig_crop.png"), orig_crop)
if bbox_in:
with open(os.path.join(dir_out, name + "_orig_bbox.json"), "w") as out:
out.write(json.dumps(bbox))
if generated_bbox:
with open(os.path.join(dir_out, name + "_generated_bbox.json"), "w") as out:
out.write(json.dumps(generated_bbox))
if opt.data_image_bits > 8:
img_np = to_np(img_tensor) # comes from PIL
cv2.imwrite(os.path.join(dir_out, name + "_orig.png"), img_np)
if cond_image is not None:
cv2.imwrite(os.path.join(dir_out, name + "_cond.png"), cond_img)
cv2.imwrite(os.path.join(dir_out, name + "_generated.png"), out_img_resized)
else:
cv2.imwrite(os.path.join(dir_out, name + "_orig.png"), img_orig)
if cond_image is not None:
cv2.imwrite(os.path.join(dir_out, name + "_cond.png"), cond_img)
cv2.imwrite(
os.path.join(dir_out, name + "_generated.png"), out_img_real_size
)
cv2.imwrite(os.path.join(dir_out, name + "_y_t.png"), to_np(y_t))
if mask is not None:
cv2.imwrite(os.path.join(dir_out, name + "_y_0.png"), to_np(img_tensor))
cv2.imwrite(
os.path.join(dir_out, name + "_generated_crop.png"), out_img
)
cv2.imwrite(os.path.join(dir_out, name + "_mask.png"), to_np(mask))
if ref is not None:
cv2.imwrite(os.path.join(dir_out, name + "_ref_orig.png"), ref_orig)
if cond_in:
# crop before cond image
orig_crop = img_orig[
bbox_select[1] : bbox_select[3], bbox_select[0] : bbox_select[2]
]
cv2.imwrite(os.path.join(dir_out, name + "_orig_crop.png"), orig_crop)
if bbox_in:
with open(os.path.join(dir_out, name + "_orig_bbox.json"), "w") as out:
out.write(json.dumps(bbox))
if generated_bbox:
with open(
os.path.join(dir_out, name + "_generated_bbox.json"), "w"
) as out:
out.write(json.dumps(generated_bbox))

print("Successfully generated image ", name)

Expand Down

0 comments on commit aefdc38

Please sign in to comment.