Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions comfy/latent_formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,11 @@ class HunyuanImage21(LatentFormat):

latent_rgb_factors_bias = [0.0007, -0.0256, -0.0206]

class HunyuanImage21Refiner(LatentFormat):
latent_channels = 64
latent_dimensions = 3
scale_factor = 1.03682

class Hunyuan3Dv2(LatentFormat):
latent_channels = 64
latent_dimensions = 1
Expand Down
11 changes: 6 additions & 5 deletions comfy/ldm/hunyuan_video/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,7 @@ def forward_orig(
guidance: Tensor = None,
guiding_frame_index=None,
ref_latent=None,
disable_time_r=False,
control=None,
transformer_options={},
) -> Tensor:
Expand All @@ -288,7 +289,7 @@ def forward_orig(
img = self.img_in(img)
vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype))

if self.time_r_in is not None:
if (self.time_r_in is not None) and (not disable_time_r):
w = torch.where(transformer_options['sigmas'][0] == transformer_options['sample_sigmas'])[0] # This most likely could be improved
if len(w) > 0:
timesteps_r = transformer_options['sample_sigmas'][w[0] + 1]
Expand Down Expand Up @@ -428,20 +429,20 @@ def img_ids_2d(self, x):
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
return repeat(img_ids, "h w c -> b (h w) c", b=bs)

def forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
def forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timestep, context, y, txt_byt5, guidance, attention_mask, guiding_frame_index, ref_latent, control, transformer_options, **kwargs)
).execute(x, timestep, context, y, txt_byt5, guidance, attention_mask, guiding_frame_index, ref_latent, disable_time_r, control, transformer_options, **kwargs)

def _forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
def _forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
bs = x.shape[0]
if len(self.patch_size) == 3:
img_ids = self.img_ids(x)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
else:
img_ids = self.img_ids_2d(x)
txt_ids = torch.zeros((bs, context.shape[1], 2), device=x.device, dtype=x.dtype)
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, guidance, guiding_frame_index, ref_latent, control=control, transformer_options=transformer_options)
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, guidance, guiding_frame_index, ref_latent, disable_time_r=disable_time_r, control=control, transformer_options=transformer_options)
return out
268 changes: 268 additions & 0 deletions comfy/ldm/hunyuan_video/vae_refiner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,268 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d
import comfy.ops
import comfy.ldm.models.autoencoder
ops = comfy.ops.disable_weight_init

class RMS_norm(nn.Module):
def __init__(self, dim):
super().__init__()
shape = (dim, 1, 1, 1)
self.scale = dim**0.5
self.gamma = nn.Parameter(torch.empty(shape))

def forward(self, x):
return F.normalize(x, dim=1) * self.scale * self.gamma

class DnSmpl(nn.Module):
def __init__(self, ic, oc, tds=True):
super().__init__()
fct = 2 * 2 * 2 if tds else 1 * 2 * 2
assert oc % fct == 0
self.conv = VideoConv3d(ic, oc // fct, kernel_size=3)

self.tds = tds
self.gs = fct * ic // oc

def forward(self, x):
r1 = 2 if self.tds else 1
h = self.conv(x)

if self.tds:
hf = h[:, :, :1, :, :]
b, c, f, ht, wd = hf.shape
hf = hf.reshape(b, c, f, ht // 2, 2, wd // 2, 2)
hf = hf.permute(0, 4, 6, 1, 2, 3, 5)
hf = hf.reshape(b, 2 * 2 * c, f, ht // 2, wd // 2)
hf = torch.cat([hf, hf], dim=1)

hn = h[:, :, 1:, :, :]
b, c, frms, ht, wd = hn.shape
nf = frms // r1
hn = hn.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2)
hn = hn.permute(0, 3, 5, 7, 1, 2, 4, 6)
hn = hn.reshape(b, r1 * 2 * 2 * c, nf, ht // 2, wd // 2)

h = torch.cat([hf, hn], dim=2)

xf = x[:, :, :1, :, :]
b, ci, f, ht, wd = xf.shape
xf = xf.reshape(b, ci, f, ht // 2, 2, wd // 2, 2)
xf = xf.permute(0, 4, 6, 1, 2, 3, 5)
xf = xf.reshape(b, 2 * 2 * ci, f, ht // 2, wd // 2)
B, C, T, H, W = xf.shape
xf = xf.view(B, h.shape[1], self.gs // 2, T, H, W).mean(dim=2)

xn = x[:, :, 1:, :, :]
b, ci, frms, ht, wd = xn.shape
nf = frms // r1
xn = xn.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2)
xn = xn.permute(0, 3, 5, 7, 1, 2, 4, 6)
xn = xn.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2)
B, C, T, H, W = xn.shape
xn = xn.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2)
sc = torch.cat([xf, xn], dim=2)
else:
b, c, frms, ht, wd = h.shape
nf = frms // r1
h = h.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2)
h = h.permute(0, 3, 5, 7, 1, 2, 4, 6)
h = h.reshape(b, r1 * 2 * 2 * c, nf, ht // 2, wd // 2)

b, ci, frms, ht, wd = x.shape
nf = frms // r1
sc = x.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2)
sc = sc.permute(0, 3, 5, 7, 1, 2, 4, 6)
sc = sc.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2)
B, C, T, H, W = sc.shape
sc = sc.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2)

return h + sc


class UpSmpl(nn.Module):
def __init__(self, ic, oc, tus=True):
super().__init__()
fct = 2 * 2 * 2 if tus else 1 * 2 * 2
self.conv = VideoConv3d(ic, oc * fct, kernel_size=3)

self.tus = tus
self.rp = fct * oc // ic

def forward(self, x):
r1 = 2 if self.tus else 1
h = self.conv(x)

if self.tus:
hf = h[:, :, :1, :, :]
b, c, f, ht, wd = hf.shape
nc = c // (2 * 2)
hf = hf.reshape(b, 2, 2, nc, f, ht, wd)
hf = hf.permute(0, 3, 4, 5, 1, 6, 2)
hf = hf.reshape(b, nc, f, ht * 2, wd * 2)
hf = hf[:, : hf.shape[1] // 2]

hn = h[:, :, 1:, :, :]
b, c, frms, ht, wd = hn.shape
nc = c // (r1 * 2 * 2)
hn = hn.reshape(b, r1, 2, 2, nc, frms, ht, wd)
hn = hn.permute(0, 4, 5, 1, 6, 2, 7, 3)
hn = hn.reshape(b, nc, frms * r1, ht * 2, wd * 2)

h = torch.cat([hf, hn], dim=2)

xf = x[:, :, :1, :, :]
b, ci, f, ht, wd = xf.shape
xf = xf.repeat_interleave(repeats=self.rp // 2, dim=1)
b, c, f, ht, wd = xf.shape
nc = c // (2 * 2)
xf = xf.reshape(b, 2, 2, nc, f, ht, wd)
xf = xf.permute(0, 3, 4, 5, 1, 6, 2)
xf = xf.reshape(b, nc, f, ht * 2, wd * 2)

xn = x[:, :, 1:, :, :]
xn = xn.repeat_interleave(repeats=self.rp, dim=1)
b, c, frms, ht, wd = xn.shape
nc = c // (r1 * 2 * 2)
xn = xn.reshape(b, r1, 2, 2, nc, frms, ht, wd)
xn = xn.permute(0, 4, 5, 1, 6, 2, 7, 3)
xn = xn.reshape(b, nc, frms * r1, ht * 2, wd * 2)
sc = torch.cat([xf, xn], dim=2)
else:
b, c, frms, ht, wd = h.shape
nc = c // (r1 * 2 * 2)
h = h.reshape(b, r1, 2, 2, nc, frms, ht, wd)
h = h.permute(0, 4, 5, 1, 6, 2, 7, 3)
h = h.reshape(b, nc, frms * r1, ht * 2, wd * 2)

sc = x.repeat_interleave(repeats=self.rp, dim=1)
b, c, frms, ht, wd = sc.shape
nc = c // (r1 * 2 * 2)
sc = sc.reshape(b, r1, 2, 2, nc, frms, ht, wd)
sc = sc.permute(0, 4, 5, 1, 6, 2, 7, 3)
sc = sc.reshape(b, nc, frms * r1, ht * 2, wd * 2)

return h + sc

class Encoder(nn.Module):
def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks,
ffactor_spatial, ffactor_temporal, downsample_match_channel=True, **_):
super().__init__()
self.z_channels = z_channels
self.block_out_channels = block_out_channels
self.num_res_blocks = num_res_blocks
self.conv_in = VideoConv3d(in_channels, block_out_channels[0], 3, 1, 1)

self.down = nn.ModuleList()
ch = block_out_channels[0]
depth = (ffactor_spatial >> 1).bit_length()
depth_temporal = ((ffactor_spatial // ffactor_temporal) >> 1).bit_length()

for i, tgt in enumerate(block_out_channels):
stage = nn.Module()
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
out_channels=tgt,
temb_channels=0,
conv_op=VideoConv3d, norm_op=RMS_norm)
for j in range(num_res_blocks)])
ch = tgt
if i < depth:
nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and downsample_match_channel else ch
stage.downsample = DnSmpl(ch, nxt, tds=i >= depth_temporal)
ch = nxt
self.down.append(stage)

self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=RMS_norm)
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)

self.norm_out = RMS_norm(ch)
self.conv_out = VideoConv3d(ch, z_channels << 1, 3, 1, 1)

self.regul = comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer()

def forward(self, x):
x = x.unsqueeze(2)
x = self.conv_in(x)

for stage in self.down:
for blk in stage.block:
x = blk(x)
if hasattr(stage, 'downsample'):
x = stage.downsample(x)

x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))

b, c, t, h, w = x.shape
grp = c // (self.z_channels << 1)
skip = x.view(b, c // grp, grp, t, h, w).mean(2)

out = self.conv_out(F.silu(self.norm_out(x))) + skip
out = self.regul(out)[0]

out = torch.cat((out[:, :, :1], out), dim=2)
out = out.permute(0, 2, 1, 3, 4)
b, f_times_2, c, h, w = out.shape
out = out.reshape(b, f_times_2 // 2, 2 * c, h, w)
out = out.permute(0, 2, 1, 3, 4).contiguous()
return out

class Decoder(nn.Module):
def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks,
ffactor_spatial, ffactor_temporal, upsample_match_channel=True, **_):
super().__init__()
block_out_channels = block_out_channels[::-1]
self.z_channels = z_channels
self.block_out_channels = block_out_channels
self.num_res_blocks = num_res_blocks

ch = block_out_channels[0]
self.conv_in = VideoConv3d(z_channels, ch, 3)

self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=RMS_norm)
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)

self.up = nn.ModuleList()
depth = (ffactor_spatial >> 1).bit_length()
depth_temporal = (ffactor_temporal >> 1).bit_length()

for i, tgt in enumerate(block_out_channels):
stage = nn.Module()
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
out_channels=tgt,
temb_channels=0,
conv_op=VideoConv3d, norm_op=RMS_norm)
for j in range(num_res_blocks + 1)])
ch = tgt
if i < depth:
nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and upsample_match_channel else ch
stage.upsample = UpSmpl(ch, nxt, tus=i < depth_temporal)
ch = nxt
self.up.append(stage)

self.norm_out = RMS_norm(ch)
self.conv_out = VideoConv3d(ch, out_channels, 3)

def forward(self, z):
z = z.permute(0, 2, 1, 3, 4)
b, f, c, h, w = z.shape
z = z.reshape(b, f, 2, c // 2, h, w)
z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w)
z = z.permute(0, 2, 1, 3, 4)
z = z[:, :, 1:]

x = self.conv_in(z) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))

for stage in self.up:
for blk in stage.block:
x = blk(x)
if hasattr(stage, 'upsample'):
x = stage.upsample(x)

return self.conv_out(F.silu(self.norm_out(x)))
6 changes: 6 additions & 0 deletions comfy/ldm/models/autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
z = posterior.mode()
return z, None

class EmptyRegularizer(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
return z, None

class AbstractAutoencoder(torch.nn.Module):
"""
Expand Down
10 changes: 5 additions & 5 deletions comfy/ldm/modules/diffusionmodules/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,15 +145,15 @@ def forward(self, x):

class ResnetBlock(nn.Module):
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
dropout=0.0, temb_channels=512, conv_op=ops.Conv2d):
dropout=0.0, temb_channels=512, conv_op=ops.Conv2d, norm_op=Normalize):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut

self.swish = torch.nn.SiLU(inplace=True)
self.norm1 = Normalize(in_channels)
self.norm1 = norm_op(in_channels)
self.conv1 = conv_op(in_channels,
out_channels,
kernel_size=3,
Expand All @@ -162,7 +162,7 @@ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
if temb_channels > 0:
self.temb_proj = ops.Linear(temb_channels,
out_channels)
self.norm2 = Normalize(out_channels)
self.norm2 = norm_op(out_channels)
self.dropout = torch.nn.Dropout(dropout, inplace=True)
self.conv2 = conv_op(out_channels,
out_channels,
Expand Down Expand Up @@ -305,11 +305,11 @@ def vae_attention():
return normal_attention

class AttnBlock(nn.Module):
def __init__(self, in_channels, conv_op=ops.Conv2d):
def __init__(self, in_channels, conv_op=ops.Conv2d, norm_op=Normalize):
super().__init__()
self.in_channels = in_channels

self.norm = Normalize(in_channels)
self.norm = norm_op(in_channels)
self.q = conv_op(in_channels,
in_channels,
kernel_size=1,
Expand Down
Loading
Loading