Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 39 additions & 22 deletions comfy_extras/nodes_hypertile.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
#Taken from: https://github.com/tfernd/HyperTile/

import math
from typing_extensions import override
from einops import rearrange
# Use torch rng for consistency across generations
from torch import randint
from comfy_api.latest import ComfyExtension, io

def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
min_value = min(min_value, value)
Expand All @@ -20,25 +22,31 @@ def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:

return ns[idx]

class HyperTile:
class HyperTile(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"tile_size": ("INT", {"default": 256, "min": 1, "max": 2048}),
"swap_size": ("INT", {"default": 2, "min": 1, "max": 128}),
"max_depth": ("INT", {"default": 0, "min": 0, "max": 10}),
"scale_depth": ("BOOLEAN", {"default": False}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"

CATEGORY = "model_patches/unet"

def patch(self, model, tile_size, swap_size, max_depth, scale_depth):
def define_schema(cls):
return io.Schema(
node_id="HyperTile",
category="model_patches/unet",
inputs=[
io.Model.Input("model"),
io.Int.Input("tile_size", default=256, min=1, max=2048),
io.Int.Input("swap_size", default=2, min=1, max=128),
io.Int.Input("max_depth", default=0, min=0, max=10),
io.Boolean.Input("scale_depth", default=False),
],
outputs=[
io.Model.Output(),
],
)

@classmethod
def execute(cls, model, tile_size, swap_size, max_depth, scale_depth) -> io.NodeOutput:
latent_tile_size = max(32, tile_size) // 8
self.temp = None
temp = None

def hypertile_in(q, k, v, extra_options):
nonlocal temp
model_chans = q.shape[-2]
orig_shape = extra_options['original_shape']
apply_to = []
Expand All @@ -58,14 +66,15 @@ def hypertile_in(q, k, v, extra_options):

if nh * nw > 1:
q = rearrange(q, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)
self.temp = (nh, nw, h, w)
temp = (nh, nw, h, w)
return q, k, v

return q, k, v
def hypertile_out(out, extra_options):
if self.temp is not None:
nh, nw, h, w = self.temp
self.temp = None
nonlocal temp
if temp is not None:
nh, nw, h, w = temp
temp = None
out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw)
out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw)
return out
Expand All @@ -76,6 +85,14 @@ def hypertile_out(out, extra_options):
m.set_model_attn1_output_patch(hypertile_out)
return (m, )

NODE_CLASS_MAPPINGS = {
"HyperTile": HyperTile,
}

class HyperTileExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
HyperTile,
]


async def comfy_entrypoint() -> HyperTileExtension:
return HyperTileExtension()
Loading