comfy_extras/nodes_train.py (83 changes: 72 additions & 11 deletions)
@@ -6,7 +6,7 @@
 import numpy as np
 import safetensors
 import torch
-from PIL import Image, ImageDraw, ImageFont
+from PIL import Image, ImageDraw, ImageFont, ImageStat
 from PIL.PngImagePlugin import PngInfo
 import torch.utils.checkpoint
 import tqdm
@@ -115,8 +115,31 @@ def move_to(self, device):
         self.to(device=device)
         return self.passive_memory_usage()
 
-
-def load_and_process_images(image_files, input_dir, resize_method="None", w=None, h=None):
+def get_edge_color(img):
+    """ code borrowed from https://github.com/kijai/ComfyUI-KJNodes """
+    """Sample edges and return dominant color"""
+    width, height = img.size
+    img = img.convert('RGB')
+
+    # Create 1-pixel high/wide images from edges
+    top = img.crop((0, 0, width, 1))
+    bottom = img.crop((0, height-1, width, height))
+    left = img.crop((0, 0, 1, height))
+    right = img.crop((width-1, 0, width, height))
+
+    # Combine edges into single image
+    edges = Image.new('RGB', (width*2 + height*2, 1))
+    edges.paste(top, (0, 0))
+    edges.paste(bottom, (width, 0))
+    edges.paste(left.resize((height, 1)), (width*2, 0))
+    edges.paste(right.resize((height, 1)), (width*2 + height, 0))
+
+    # Get median color
+    stat = ImageStat.Stat(edges)
+    median = tuple(map(int, stat.median))
+    return median
+
+def load_and_process_images(image_files, input_dir, resize_method="None", width=None, height=None):
     """Utility function to load and process a list of images.
 
     Args:
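Note: the new get_edge_color helper above picks a padding color from the image itself rather than using flat black: it concatenates the four one-pixel edge strips into a single 1-pixel-high image and takes the per-channel median. A minimal sketch of its behavior, using a hypothetical solid-bordered test image (the sizes and colors here are made up for illustration):

    from PIL import Image, ImageDraw

    # Hypothetical test image: dark blue background, red square in the center.
    img = Image.new("RGB", (64, 64), (10, 20, 80))
    ImageDraw.Draw(img).rectangle((24, 24, 40, 40), fill=(255, 0, 0))

    # Every border pixel is (10, 20, 80), so the per-channel median of the
    # combined edge strip ignores the red center entirely.
    print(get_edge_color(img))  # -> (10, 20, 80)

Because it takes the median rather than the mean, a few stray bright pixels on the border will not tint the resulting padding color.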
@@ -140,23 +163,61 @@ def load_and_process_images(image_files, input_dir, resize_method="None", w=None, h=None):
             img = img.point(lambda i: i * (1 / 255))
         img = img.convert("RGB")
 
-        if w is None and h is None:
-            w, h = img.size[0], img.size[1]
+        if width is None and height is None:
+            width, height = img.size[0], img.size[1]
 
         # Resize image to first image
-        if img.size[0] != w or img.size[1] != h:
+        if img.size[0] != width or img.size[1] != height:
+            """ code partially borrowed from https://github.com/kijai/ComfyUI-KJNodes """
             if resize_method == "Stretch":
-                img = img.resize((w, h), Image.Resampling.LANCZOS)
-            elif resize_method == "Crop":
-                img = img.crop((0, 0, w, h))
+                img = img.resize((width, height), Image.Resampling.LANCZOS)
+
+            img_width, img_height = img.size
+            aspect_ratio = img_width / img_height
+            target_ratio = width / height
+
+            if resize_method == "Crop":
+                # Calculate dimensions for center crop
+                if aspect_ratio > target_ratio:
+                    # Image is wider - crop width
+                    new_width = int(height * aspect_ratio)
+                    img = img.resize((new_width, height), Image.Resampling.LANCZOS)
+                    left = (new_width - width) // 2
+                    img = img.crop((left, 0, left + width, height))
+                else:
+                    # Image is taller - crop height
+                    new_height = int(width / aspect_ratio)
+                    img = img.resize((width, new_height), Image.Resampling.LANCZOS)
+                    top = (new_height - height) // 2
+                    img = img.crop((0, top, width, top + height))
+
+            elif resize_method == "Pad":
-                img = img.resize((w, h), Image.Resampling.LANCZOS)
+                pad_color = get_edge_color(img)
+                # Calculate dimensions for padding
+                if aspect_ratio > target_ratio:
+                    # Image is wider - pad height
+                    new_height = int(width / aspect_ratio)
+                    resized = img.resize((width, new_height), Image.Resampling.LANCZOS)
+                    padding = (height - new_height) // 2
+                    padded = Image.new('RGB', (width, height), pad_color)
+                    padded.paste(resized, (0, padding))
+                    img = padded
+                else:
+                    # Image is taller - pad width
+                    new_width = int(height * aspect_ratio)
+                    resized = img.resize((new_width, height), Image.Resampling.LANCZOS)
+                    padding = (width - new_width) // 2
+                    padded = Image.new('RGB', (width, height), pad_color)
+                    padded.paste(resized, (padding, 0))
+                    img = padded
+
             elif resize_method == "None":
                 raise ValueError(
                     "Your input image size does not match the first image in the dataset. Either select a valid resize method or use the same size for all images."
                 )
 
-        img_array = np.array(img).astype(np.float32) / 255.0
+        img_array = np.array(img).astype(np.float32)
+        img_array = img_array / np.float32(255.0)
         img_tensor = torch.from_numpy(img_array)[None,]
         output_images.append(img_tensor)
 
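To sanity-check the aspect-ratio arithmetic in the new Crop and Pad branches above, here is the same math traced with illustrative sizes; the 512x512 target and both source sizes are made-up values, not anything from the PR:

    # Crop: 1000x500 source into a 512x512 target.
    width, height = 512, 512                # target size (taken from the first image)
    img_width, img_height = 1000, 500       # hypothetical source
    aspect_ratio = img_width / img_height   # 2.0
    target_ratio = width / height           # 1.0 -> source is wider, so crop the width
    new_width = int(height * aspect_ratio)  # 1024: resize so the height matches exactly
    left = (new_width - width) // 2         # 256: center a 512-wide window
    assert (left, 0, left + width, height) == (256, 0, 768, 512)

    # Pad: the mirror case, a 500x1000 source into the same target.
    img_width, img_height = 500, 1000
    aspect_ratio = img_width / img_height   # 0.5 -> source is taller, so pad the width
    new_width = int(height * aspect_ratio)  # 256: resize so the height matches exactly
    padding = (width - new_width) // 2      # 128 columns of edge color on each side
    assert padding == 128

In both branches the image is first scaled so that the matching dimension fits the target exactly; the excess in the other dimension is then either cropped away symmetrically or filled symmetrically with the sampled edge color, keeping the subject centered.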
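The final change splits the float conversion into two statements and divides by an explicit np.float32 scalar, making the intent (stay in float32 throughout) explicit in the code. A self-contained sketch of the resulting tensor, with made-up 2x2 pixel data:

    import numpy as np
    import torch

    # Hypothetical 2x2 RGB image as raw uint8 pixels.
    pixels = np.array([[[255, 0, 0], [0, 255, 0]],
                       [[0, 0, 255], [255, 255, 255]]], dtype=np.uint8)

    img_array = pixels.astype(np.float32)
    img_array = img_array / np.float32(255.0)        # values in [0.0, 1.0], dtype float32
    img_tensor = torch.from_numpy(img_array)[None,]  # add batch dim -> (1, 2, 2, 3)
    print(img_tensor.shape, img_tensor.dtype)        # torch.Size([1, 2, 2, 3]) torch.float32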