fix: 🐛 colored image mask input
- The mask input was completely broken; it now works.
- The node's size can now differ from that of the input images; the foreground is cropped or expanded to fit (see the sketch below).
- The image and mask tensors must still be the same size.
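A minimal sketch of the behavior described above. The helper name `cover_and_center_crop`, the example sizes, and the `LANCZOS` filter are illustrative assumptions, not the node's actual code: the foreground is scaled to cover the target canvas while keeping its aspect ratio, center-cropped, and then composited onto the constant-color background through the mask.

# Illustrative only -- names and sizes are hypothetical, not taken from the node.
from PIL import Image


def cover_and_center_crop(img, target_size):
    # Scale so the image fully covers the target while keeping its aspect ratio.
    scale = max(target_size[0] / img.width, target_size[1] / img.height)
    img = img.resize(
        (round(img.width * scale), round(img.height * scale)),
        Image.LANCZOS,  # Pillow 10 removed ANTIALIAS; LANCZOS is the usual replacement
    )
    # Center-crop whichever dimension overshoots the target.
    left = (img.width - target_size[0]) // 2
    top = (img.height - target_size[1]) // 2
    return img.crop((left, top, left + target_size[0], top + target_size[1]))


canvas = Image.new("RGBA", (256, 256), "black")  # constant-color background
fg = Image.new("RGB", (512, 768), "red")         # foreground of a different size
mask = Image.new("L", (512, 768), 128)           # mask must match the foreground size

fg = cover_and_center_crop(fg, canvas.size)
mask = cover_and_center_crop(mask, canvas.size)
out = Image.composite(fg.convert("RGBA"), canvas, mask).convert("RGB")
print(out.size)  # (256, 256)

`Image.composite` stands in for the node's mask path; when no mask is given, the foreground must already carry an alpha channel, which is the RGBA check enforced in the diff below.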
melMass committed Jan 5, 2024
1 parent 6c5e5d3 commit 30c4311
Showing 1 changed file with 100 additions and 35 deletions.
nodes/image_processing.py: 135 changes (100 additions, 35 deletions)
@@ -22,7 +22,9 @@
 # log.warning("cv2.ximgproc.guidedFilter not found, use opencv-contrib-python")
 
 
-def gaussian_kernel(kernel_size: int, sigma_x: float, sigma_y: float, device=None):
+def gaussian_kernel(
+    kernel_size: int, sigma_x: float, sigma_y: float, device=None
+):
     x, y = torch.meshgrid(
         torch.linspace(-1, 1, kernel_size, device=device),
         torch.linspace(-1, 1, kernel_size, device=device),
@@ -115,7 +117,9 @@ def hsv_adjustment(image: torch.Tensor, hue, saturation, value):
         return pil2tensor(out)
 
     @staticmethod
-    def hsv_adjustment_tensor_not_working(image: torch.Tensor, hue, saturation, value):
+    def hsv_adjustment_tensor_not_working(
+        image: torch.Tensor, hue, saturation, value
+    ):
         """Abandonning for now"""
         image = image.squeeze(0).permute(2, 0, 1)
 
@@ -344,11 +348,16 @@ def do_sharp(
             (sharpen_radius, sharpen_radius, sharpen_radius, sharpen_radius),
             "reflect",
         )
-        sharpened = F.conv2d(tensor_image, kernel, padding=center, groups=channels)
+        sharpened = F.conv2d(
+            tensor_image, kernel, padding=center, groups=channels
+        )
 
         # Remove padding
         sharpened = sharpened[
-            :, :, sharpen_radius:-sharpen_radius, sharpen_radius:-sharpen_radius
+            :,
+            :,
+            sharpen_radius:-sharpen_radius,
+            sharpen_radius:-sharpen_radius,
         ]
 
         sharpened = sharpened.permute(0, 2, 3, 1)
@@ -408,7 +417,9 @@ def render_mask(self, mask, color, background):
         for m in masks:
             _mask = Image.fromarray(m).convert("L")
 
-            log.debug(f"Converted mask to PIL Image format, size: {_mask.size}")
+            log.debug(
+                f"Converted mask to PIL Image format, size: {_mask.size}"
+            )
 
             image = Image.new("RGBA", _mask.size, color=color)
             # apply the mask
@@ -425,8 +436,11 @@ def render_mask(self, mask, color, background):
         return (pil2tensor(images),)
 
 
+from typing import Optional
+
+
 class ColoredImage:
-    """Constant color image of given size"""
+    """Constant color image of given size."""
 
     def __init__(self) -> None:
         pass
@@ -451,40 +465,82 @@ def INPUT_TYPES(cls):
 
     FUNCTION = "render_img"
 
+    def resize_and_crop(self, img, target_size):
+        # Calculate scaling factors for both dimensions
+        scale_x = target_size[0] / img.width
+        scale_y = target_size[1] / img.height
+
+        # Use the smaller scaling factor to maintain aspect ratio
+        scale = max(scale_x, scale_y)
+
+        # Resize the image based on calculated scale
+        new_size = (int(img.width * scale), int(img.height * scale))
+        img = img.resize(new_size, Image.ANTIALIAS)
+
+        # Calculate cropping coordinates
+        left = (img.width - target_size[0]) / 2
+        top = (img.height - target_size[1]) / 2
+        right = (img.width + target_size[0]) / 2
+        bottom = (img.height + target_size[1]) / 2
+
+        # Crop and return the image
+        return img.crop((left, top, right, bottom))
+
+    def resize_and_crop_thumbnails(self, img, target_size):
+        img.thumbnail(target_size, Image.ANTIALIAS)
+        left = (img.width - target_size[0]) / 2
+        top = (img.height - target_size[1]) / 2
+        right = (img.width + target_size[0]) / 2
+        bottom = (img.height + target_size[1]) / 2
+        return img.crop((left, top, right, bottom))
+
     def render_img(
-        self, color, width, height, foreground_image=None, foreground_mask=None
+        self,
+        color,
+        width,
+        height,
+        foreground_image: Optional[torch.Tensor] = None,
+        foreground_mask: Optional[torch.Tensor] = None,
     ):
         image = Image.new("RGBA", (width, height), color=color)
         output = []
         if foreground_image is not None:
-            if foreground_mask is None:
-                fg_images = tensor2pil(foreground_image)
-                for img in fg_images:
-                    if image.size != img.size:
-                        raise ValueError(
-                            f"Dimension mismatch: image {image.size}, img {img.size}"
-                        )
+            fg_images = tensor2pil(foreground_image)
+            fg_masks = [None] * len(
+                fg_images
+            )  # Default to None for each foreground image
+
+            if foreground_mask is not None:
+                if foreground_image.size()[0] != foreground_mask.size()[0]:
+                    raise ValueError(
+                        "Foreground image and mask must have same batch size"
+                    )
+                fg_masks = tensor2pil(foreground_mask)
 
-                    if img.mode != "RGBA":
+            for fg_image, fg_mask in zip(fg_images, fg_masks):
+                # Resize and crop if dimensions mismatch
+                if fg_image.size != image.size:
+                    fg_image = self.resize_and_crop(fg_image, image.size)
+                    if fg_mask:
+                        fg_mask = self.resize_and_crop(fg_mask, image.size)
+
+                if fg_mask:
+                    output.append(
+                        Image.composite(
+                            fg_image.convert("RGBA"),
+                            image,
+                            fg_mask,
+                        ).convert("RGB")
+                    )
+                else:
+                    if fg_image.mode != "RGBA":
                         raise ValueError(
-                            f"Foreground image must be in 'RGBA' mode when no mask is provided, got {img.mode}"
+                            f"Foreground image must be in 'RGBA' mode when no mask is provided, got {fg_image.mode}"
                         )
+                    output.append(
+                        Image.alpha_composite(image, fg_image).convert("RGB")
+                    )
 
-                    output.append(Image.alpha_composite(image, img).convert("RGB"))
-
-            elif foreground_image.size[0] != foreground_mask.size[0]:
-                raise ValueError("Foreground image and mask must have same batch size")
-            else:
-                fg_images = tensor2pil(foreground_image)
-                fg_masks = tensor2pil(foreground_mask)
-                output.extend(
-                    Image.composite(
-                        fg_image.convert("RGBA"),
-                        image,
-                        fg_mask,
-                    ).convert("RGB")
-                    for fg_image, fg_mask in zip(fg_images, fg_masks)
-                )
         elif foreground_mask is not None:
             log.warn("Mask ignored because no foreground image is given")
 
@@ -577,7 +633,9 @@ def resize(
     ):
         # Check if the tensor has the correct dimension
        if len(image.shape) not in [3, 4]:  # HxWxC or BxHxWxC
-            raise ValueError("Expected image tensor of shape (H, W, C) or (B, H, W, C)")
+            raise ValueError(
+                "Expected image tensor of shape (H, W, C) or (B, H, W, C)"
+            )
 
         # Transpose to CxHxW or BxCxHxW for PyTorch
         if len(image.shape) == 3:
@@ -689,7 +747,10 @@ def save_images(
             subfolder,
             filename_prefix,
         ) = folder_paths.get_save_image_path(
-            filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]
+            filename_prefix,
+            self.output_dir,
+            images[0].shape[1],
+            images[0].shape[0],
         )
         image_list = []
         batch_counter = counter
@@ -719,10 +780,14 @@ def save_images(
         file = f"{filename}_{counter:05}_.png"
         grid = self.create_image_grid(image_list)
         grid.save(
-            os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=4
+            os.path.join(full_output_folder, file),
+            pnginfo=metadata,
+            compress_level=4,
        )
 
-        results = [{"filename": file, "subfolder": subfolder, "type": self.type}]
+        results = [
+            {"filename": file, "subfolder": subfolder, "type": self.type}
+        ]
         return {"ui": {"images": results}}
