4 changes: 2 additions & 2 deletions comfy/ldm/lumina/model.py

@@ -509,7 +509,7 @@ def patchify_and_embed(

         if self.pad_tokens_multiple is not None:
             pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple
-            cap_feats = torch.cat((cap_feats, self.cap_pad_token.to(device=cap_feats.device, dtype=cap_feats.dtype).unsqueeze(0).repeat(cap_feats.shape[0], pad_extra, 1)), dim=1)
+            cap_feats = torch.cat((cap_feats, self.cap_pad_token.to(device=cap_feats.device, dtype=cap_feats.dtype, copy=True).unsqueeze(0).repeat(cap_feats.shape[0], pad_extra, 1)), dim=1)

         cap_pos_ids = torch.zeros(bsz, cap_feats.shape[1], 3, dtype=torch.float32, device=device)
         cap_pos_ids[:, :, 0] = torch.arange(cap_feats.shape[1], dtype=torch.float32, device=device) + 1.0
@@ -525,7 +525,7 @@ def patchify_and_embed(

         if self.pad_tokens_multiple is not None:
             pad_extra = (-x.shape[1]) % self.pad_tokens_multiple
-            x = torch.cat((x, self.x_pad_token.to(device=x.device, dtype=x.dtype).unsqueeze(0).repeat(x.shape[0], pad_extra, 1)), dim=1)
+            x = torch.cat((x, self.x_pad_token.to(device=x.device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(x.shape[0], pad_extra, 1)), dim=1)
             x_pos_ids = torch.nn.functional.pad(x_pos_ids, (0, 0, 0, pad_extra))

         freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2)
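Note: `Tensor.to()` returns `self` unchanged when the requested device and dtype already match, so without `copy=True` the concatenated pad token can alias the module's learned parameter; presumably that is why this change forces a copy. A minimal standalone sketch of both behaviors (not ComfyUI code):

```python
import torch

pad_token = torch.nn.Parameter(torch.zeros(8))

# Without copy=True, .to() is a no-op when device/dtype already match:
# the result is the parameter itself, not a new tensor.
alias = pad_token.to(device=pad_token.device, dtype=pad_token.dtype)
print(alias.data_ptr() == pad_token.data_ptr())  # True (aliases the weight)

# copy=True always materializes a fresh tensor, so downstream code cannot
# accidentally mutate the learned pad token in place.
fresh = pad_token.to(device=pad_token.device, dtype=pad_token.dtype, copy=True)
print(fresh.data_ptr() == pad_token.data_ptr())  # False

# The padding amount used above rounds the sequence length up to the next
# multiple: (-n) % m == 0 when n is already a multiple of m.
n, m = 77, 32
print((-n) % m)  # 19 -> 77 + 19 == 96 == 3 * 32
```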
102 changes: 1 addition & 101 deletions comfy_extras/nodes_dataset.py

@@ -1,6 +1,5 @@
 import logging
 import os
-import math
 import json

 import numpy as np
@@ -624,79 +623,6 @@ def _group_process(cls, texts, **kwargs):
 # ========== Image Transform Nodes ==========


-class ResizeImagesToSameSizeNode(ImageProcessingNode):
-    node_id = "ResizeImagesToSameSize"
-    display_name = "Resize Images to Same Size"
-    description = "Resize all images to the same width and height."
-    extra_inputs = [
-        io.Int.Input("width", default=512, min=1, max=8192, tooltip="Target width."),
-        io.Int.Input("height", default=512, min=1, max=8192, tooltip="Target height."),
-        io.Combo.Input(
-            "mode",
-            options=["stretch", "crop_center", "pad"],
-            default="stretch",
-            tooltip="Resize mode.",
-        ),
-    ]
-
-    @classmethod
-    def _process(cls, image, width, height, mode):
-        img = tensor_to_pil(image)
-
-        if mode == "stretch":
-            img = img.resize((width, height), Image.Resampling.LANCZOS)
-        elif mode == "crop_center":
-            left = max(0, (img.width - width) // 2)
-            top = max(0, (img.height - height) // 2)
-            right = min(img.width, left + width)
-            bottom = min(img.height, top + height)
-            img = img.crop((left, top, right, bottom))
-            if img.width != width or img.height != height:
-                img = img.resize((width, height), Image.Resampling.LANCZOS)
-        elif mode == "pad":
-            img.thumbnail((width, height), Image.Resampling.LANCZOS)
-            new_img = Image.new("RGB", (width, height), (0, 0, 0))
-            paste_x = (width - img.width) // 2
-            paste_y = (height - img.height) // 2
-            new_img.paste(img, (paste_x, paste_y))
-            img = new_img
-
-        return pil_to_tensor(img)


-class ResizeImagesToPixelCountNode(ImageProcessingNode):
-    node_id = "ResizeImagesToPixelCount"
-    display_name = "Resize Images to Pixel Count"
-    description = "Resize images so that the total pixel count matches the specified number while preserving aspect ratio."
-    extra_inputs = [
-        io.Int.Input(
-            "pixel_count",
-            default=512 * 512,
-            min=1,
-            max=8192 * 8192,
-            tooltip="Target pixel count.",
-        ),
-        io.Int.Input(
-            "steps",
-            default=64,
-            min=1,
-            max=128,
-            tooltip="The stepping for resize width/height.",
-        ),
-    ]
-
-    @classmethod
-    def _process(cls, image, pixel_count, steps):
-        img = tensor_to_pil(image)
-        w, h = img.size
-        pixel_count_ratio = math.sqrt(pixel_count / (w * h))
-        new_w = int(w * pixel_count_ratio / steps) * steps
-        new_h = int(h * pixel_count_ratio / steps) * steps
-        logging.info(f"Resizing from {w}x{h} to {new_w}x{new_h}")
-        img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
-        return pil_to_tensor(img)
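For reference, the removed node sized images by scaling both edges by `sqrt(pixel_count / (w * h))`, so the area approaches the target, then snapping each edge down to a multiple of `steps`. A standalone worked example of that arithmetic:

```python
import math

# Example inputs: a 1920x1080 image, a 512*512 pixel budget, 64-px steps.
w, h, pixel_count, steps = 1920, 1080, 512 * 512, 64

ratio = math.sqrt(pixel_count / (w * h))  # ~0.3556: equal scale on both edges
new_w = int(w * ratio / steps) * steps    # 1920 * 0.3556 ~= 682.7 -> 640
new_h = int(h * ratio / steps) * steps    # 1080 * 0.3556 ~= 384.0 -> 384
print(new_w, new_h, new_w * new_h)        # 640 384 245760 (<= 262144 budget)
```

Because `int()` truncates, each edge snaps downward, so the result never exceeds the pixel budget by more than the aspect-ratio rounding allows.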


 class ResizeImagesByShorterEdgeNode(ImageProcessingNode):
     node_id = "ResizeImagesByShorterEdge"
     display_name = "Resize Images by Shorter Edge"
@@ -801,29 +727,6 @@ def _process(cls, image, width, height, seed):
         return pil_to_tensor(img)


-class FlipImagesNode(ImageProcessingNode):
-    node_id = "FlipImages"
-    display_name = "Flip Images"
-    description = "Flip all images horizontally or vertically."
-    extra_inputs = [
-        io.Combo.Input(
-            "direction",
-            options=["horizontal", "vertical"],
-            default="horizontal",
-            tooltip="Flip direction.",
-        ),
-    ]
-
-    @classmethod
-    def _process(cls, image, direction):
-        img = tensor_to_pil(image)
-        if direction == "horizontal":
-            img = img.transpose(Image.FLIP_LEFT_RIGHT)
-        else:
-            img = img.transpose(Image.FLIP_TOP_BOTTOM)
-        return pil_to_tensor(img)


 class NormalizeImagesNode(ImageProcessingNode):
     node_id = "NormalizeImages"
     display_name = "Normalize Images"
@@ -1470,7 +1373,7 @@ def execute(cls, folder_name):
             shard_path = os.path.join(dataset_dir, shard_file)

             with open(shard_path, "rb") as f:
-                shard_data = torch.load(f)
+                shard_data = torch.load(f, weights_only=True)

             all_latents.extend(shard_data["latents"])
             all_conditioning.extend(shard_data["conditioning"])
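Note: `weights_only=True` switches `torch.load` to its restricted unpickler, which only reconstructs tensors and plain Python containers instead of executing arbitrary pickle bytecode, so a tampered shard file cannot run code at load time. A minimal sketch of the round trip (the `shard_0.pt` filename is illustrative):

```python
import torch

# A shard holding only tensors and plain containers loads fine under the
# restricted unpickler; arbitrary pickled objects would be rejected.
shard = {
    "latents": [torch.randn(4, 32, 32)],
    "conditioning": [torch.randn(1, 77, 768)],
}
torch.save(shard, "shard_0.pt")

with open("shard_0.pt", "rb") as f:
    loaded = torch.load(f, weights_only=True)  # refuses non-tensor payloads
print(loaded["latents"][0].shape)  # torch.Size([4, 32, 32])
```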
@@ -1496,13 +1399,10 @@ async def get_node_list(self) -> list[type[io.ComfyNode]]:
             SaveImageDataSetToFolderNode,
             SaveImageTextDataSetToFolderNode,
             # Image transform nodes
-            ResizeImagesToSameSizeNode,
-            ResizeImagesToPixelCountNode,
             ResizeImagesByShorterEdgeNode,
             ResizeImagesByLongerEdgeNode,
             CenterCropImagesNode,
             RandomCropImagesNode,
-            FlipImagesNode,
             NormalizeImagesNode,
             AdjustBrightnessNode,
             AdjustContrastNode,