diff --git a/comfy/audio_encoders/audio_encoders.py b/comfy/audio_encoders/audio_encoders.py index 0550b2f9bc44..46ef21c95cf1 100644 --- a/comfy/audio_encoders/audio_encoders.py +++ b/comfy/audio_encoders/audio_encoders.py @@ -41,6 +41,7 @@ def encode_audio(self, audio, sample_rate): outputs = {} outputs["encoded_audio"] = out outputs["encoded_audio_all_layers"] = all_layers + outputs["audio_samples"] = audio.shape[2] return outputs diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 67dcf8f1e8bb..b3b7da5d5b4d 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -34,7 +34,9 @@ def __init__(self, num_heads, window_size=(-1, -1), qk_norm=True, - eps=1e-6, operation_settings={}): + eps=1e-6, + kv_dim=None, + operation_settings={}): assert dim % num_heads == 0 super().__init__() self.dim = dim @@ -43,11 +45,13 @@ def __init__(self, self.window_size = window_size self.qk_norm = qk_norm self.eps = eps + if kv_dim is None: + kv_dim = dim # layers self.q = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) - self.k = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) - self.v = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.k = operation_settings.get("operations").Linear(kv_dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.v = operation_settings.get("operations").Linear(kv_dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) self.o = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) self.norm_q = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity() self.norm_k = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity() @@ -402,6 +406,7 @@ def __init__(self, eps=1e-6, flf_pos_embed_token_number=None, in_dim_ref_conv=None, + wan_attn_block_class=WanAttentionBlock, image_model=None, device=None, dtype=None, @@ -479,8 +484,8 @@ def __init__(self, # blocks cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn' self.blocks = nn.ModuleList([ - WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads, - window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings) + wan_attn_block_class(cross_attn_type, dim, ffn_dim, num_heads, + window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings) for _ in range(num_layers) ]) @@ -1325,3 +1330,247 @@ def block_wrap(args): # unpatchify x = self.unpatchify(x, grid_sizes) return x + + +class WanT2VCrossAttentionGather(WanSelfAttention): + + def forward(self, x, context, transformer_options={}, **kwargs): + r""" + Args: + x(Tensor): Shape [B, L1, C] - video tokens + context(Tensor): Shape [B, L2, C] - audio tokens with shape [B, frames*16, 1536] + """ + b, n, d = x.size(0), self.num_heads, self.head_dim + + q = self.norm_q(self.q(x)) + k = self.norm_k(self.k(context)) + v = self.v(context) + + # Handle audio temporal structure (16 tokens per frame) + k = k.reshape(-1, 16, n, 
d).transpose(1, 2)
+        v = v.reshape(-1, 16, n, d).transpose(1, 2)
+
+        # Handle video spatial structure
+        q = q.reshape(k.shape[0], -1, n, d).transpose(1, 2)
+
+        x = optimized_attention(q, k, v, heads=self.num_heads, skip_reshape=True, skip_output_reshape=True, transformer_options=transformer_options)
+
+        x = x.transpose(1, 2).view(b, -1, n, d).flatten(2)
+        x = self.o(x)
+        return x
+
+
+class AudioCrossAttentionWrapper(nn.Module):
+    def __init__(self, dim, kv_dim, num_heads, qk_norm=True, eps=1e-6, operation_settings={}):
+        super().__init__()
+
+        # audio tokens are kv_dim (1536) wide, so kv_dim must reach the k/v projections
+        self.audio_cross_attn = WanT2VCrossAttentionGather(dim, num_heads, qk_norm=qk_norm, eps=eps, kv_dim=kv_dim, operation_settings=operation_settings)
+        self.norm1_audio = operation_settings.get("operations").LayerNorm(dim, eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+
+    def forward(self, x, audio, transformer_options={}):
+        x = x + self.audio_cross_attn(self.norm1_audio(x), audio, transformer_options=transformer_options)
+        return x
+
+
+class WanAttentionBlockAudio(WanAttentionBlock):
+
+    def __init__(self,
+                 cross_attn_type,
+                 dim,
+                 ffn_dim,
+                 num_heads,
+                 window_size=(-1, -1),
+                 qk_norm=True,
+                 cross_attn_norm=False,
+                 eps=1e-6, operation_settings={}):
+        super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps, operation_settings)
+        self.audio_cross_attn_wrapper = AudioCrossAttentionWrapper(dim, 1536, num_heads, qk_norm, eps, operation_settings=operation_settings)
+
+    def forward(
+        self,
+        x,
+        e,
+        freqs,
+        context,
+        context_img_len=257,
+        audio=None,
+        transformer_options={},
+    ):
+        r"""
+        Args:
+            x(Tensor): Shape [B, L, C]
+            e(Tensor): Shape [B, 6, C]
+            freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
+        """
+        # assert e.dtype == torch.float32
+
+        if e.ndim < 4:
+            e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
+        else:
+            e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e).unbind(2)
+        # assert e[0].dtype == torch.float32
+
+        # self-attention
+        y = self.self_attn(
+            torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
+            freqs, transformer_options=transformer_options)
+
+        x = torch.addcmul(x, y, repeat_e(e[2], x))
+
+        # cross-attention & ffn
+        x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)
+        if audio is not None:
+            x = self.audio_cross_attn_wrapper(x, audio, transformer_options=transformer_options)
+        y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
+        x = torch.addcmul(x, y, repeat_e(e[5], x))
+        return x
+
+class DummyAdapterLayer(nn.Module):
+    def __init__(self, layer):
+        super().__init__()
+        self.layer = layer
+
+    def forward(self, *args, **kwargs):
+        return self.layer(*args, **kwargs)
+
+
+class AudioProjModel(nn.Module):
+    def __init__(
+        self,
+        seq_len=5,
+        blocks=13,  # add a new parameter blocks
+        channels=768,  # add a new parameter channels
+        intermediate_dim=512,
+        output_dim=1536,
+        context_tokens=16,
+        device=None,
+        dtype=None,
+        operations=None,
+    ):
+        super().__init__()
+
+        self.seq_len = seq_len
+        self.blocks = blocks
+        self.channels = channels
+        self.input_dim = seq_len * blocks * channels  # update input_dim to be the product of blocks and channels.
+ self.intermediate_dim = intermediate_dim + self.context_tokens = context_tokens + self.output_dim = output_dim + + # define multiple linear layers + self.audio_proj_glob_1 = DummyAdapterLayer(operations.Linear(self.input_dim, intermediate_dim, dtype=dtype, device=device)) + self.audio_proj_glob_2 = DummyAdapterLayer(operations.Linear(intermediate_dim, intermediate_dim, dtype=dtype, device=device)) + self.audio_proj_glob_3 = DummyAdapterLayer(operations.Linear(intermediate_dim, context_tokens * output_dim, dtype=dtype, device=device)) + + self.audio_proj_glob_norm = DummyAdapterLayer(operations.LayerNorm(output_dim, dtype=dtype, device=device)) + + def forward(self, audio_embeds): + video_length = audio_embeds.shape[1] + audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c") + batch_size, window_size, blocks, channels = audio_embeds.shape + audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels) + + audio_embeds = torch.relu(self.audio_proj_glob_1(audio_embeds)) + audio_embeds = torch.relu(self.audio_proj_glob_2(audio_embeds)) + + context_tokens = self.audio_proj_glob_3(audio_embeds).reshape(batch_size, self.context_tokens, self.output_dim) + + context_tokens = self.audio_proj_glob_norm(context_tokens) + context_tokens = rearrange(context_tokens, "(bz f) m c -> bz f m c", f=video_length) + + return context_tokens + + +class HumoWanModel(WanModel): + r""" + Wan diffusion backbone supporting both text-to-video and image-to-video. + """ + + def __init__(self, + model_type='humo', + patch_size=(1, 2, 2), + text_len=512, + in_dim=16, + dim=2048, + ffn_dim=8192, + freq_dim=256, + text_dim=4096, + out_dim=16, + num_heads=16, + num_layers=32, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=True, + eps=1e-6, + flf_pos_embed_token_number=None, + image_model=None, + audio_token_num=16, + device=None, + dtype=None, + operations=None, + ): + + super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, wan_attn_block_class=WanAttentionBlockAudio, image_model=image_model, device=device, dtype=dtype, operations=operations) + + self.audio_proj = AudioProjModel(seq_len=8, blocks=5, channels=1280, intermediate_dim=512, output_dim=1536, context_tokens=audio_token_num, dtype=dtype, device=device, operations=operations) + + def forward_orig( + self, + x, + t, + context, + freqs=None, + audio_embed=None, + reference_latent=None, + transformer_options={}, + **kwargs, + ): + bs, _, time, height, width = x.shape + + # embeddings + x = self.patch_embedding(x.float()).to(x.dtype) + grid_sizes = x.shape[2:] + x = x.flatten(2).transpose(1, 2) + + # time embeddings + e = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype)) + e = e.reshape(t.shape[0], -1, e.shape[-1]) + e0 = self.time_projection(e).unflatten(2, (6, self.dim)) + + if reference_latent is not None: + ref = self.patch_embedding(reference_latent.float()).to(x.dtype) + ref = ref.flatten(2).transpose(1, 2) + freqs_ref = self.rope_encode(reference_latent.shape[-3], reference_latent.shape[-2], reference_latent.shape[-1], t_start=time, device=x.device, dtype=x.dtype) + x = torch.cat([x, ref], dim=1) + freqs = torch.cat([freqs, freqs_ref], dim=1) + del ref, freqs_ref + + # context + 
context = self.text_embedding(context) + context_img_len = None + + if audio_embed is not None: + audio = self.audio_proj(audio_embed).permute(0, 3, 1, 2).flatten(2).transpose(1, 2) + else: + audio = None + + patches_replace = transformer_options.get("patches_replace", {}) + blocks_replace = patches_replace.get("dit", {}) + for i, block in enumerate(self.blocks): + if ("double_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, audio=audio, transformer_options=args["transformer_options"]) + return out + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap}) + x = out["img"] + else: + x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, audio=audio, transformer_options=transformer_options) + + # head + x = self.head(x, e) + + # unpatchify + x = self.unpatchify(x, grid_sizes) + return x diff --git a/comfy/model_base.py b/comfy/model_base.py index 252dfcf69d6d..cf99035da32e 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1213,6 +1213,23 @@ def extra_conds(self, **kwargs): out['camera_conditions'] = comfy.conds.CONDRegular(camera_conditions) return out +class WAN21_HuMo(WAN21): + def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): + super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.HumoWanModel) + self.image_to_video = image_to_video + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + + audio_embed = kwargs.get("audio_embed", None) + if audio_embed is not None: + out['audio_embed'] = comfy.conds.CONDRegular(audio_embed) + + reference_latents = kwargs.get("reference_latents", None) + if reference_latents is not None: + out['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1])) + return out + class WAN22_S2V(WAN21): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel_S2V) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 03d44f65e2a9..72621bed650f 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -402,6 +402,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["model_type"] = "camera_2.2" elif '{}casual_audio_encoder.encoder.final_linear.weight'.format(key_prefix) in state_dict_keys: dit_config["model_type"] = "s2v" + elif '{}audio_proj.audio_proj_glob_1.layer.bias'.format(key_prefix) in state_dict_keys: + dit_config["model_type"] = "humo" else: if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys: dit_config["model_type"] = "i2v" diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 557902d110b6..213b5b92c826 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1073,6 +1073,16 @@ def get_model(self, state_dict, prefix="", device=None): out = model_base.WAN21_Vace(self, image_to_video=False, device=device) return out +class WAN21_HuMo(WAN21_T2V): + unet_config = { + "image_model": "wan2.1", + "model_type": "humo", + } + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.WAN21_HuMo(self, image_to_video=False, device=device) + return out + class WAN22_S2V(WAN21_T2V): unet_config = 
{ "image_model": "wan2.1", @@ -1351,6 +1361,6 @@ def get_model(self, state_dict, prefix="", device=None): out = model_base.HunyuanImage21Refiner(self, device=device) return out -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage] models += [SVD_img2vid] diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index 4f73369f5b03..0b8b55813cf2 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -1015,6 +1015,103 @@ def execute(cls, positive, negative, vae, length, video_latent, ref_image=None, return io.NodeOutput(positive, negative, out_latent) +def get_audio_emb_window(audio_emb, frame_num, frame0_idx, audio_shift=2): + zero_audio_embed = torch.zeros((audio_emb.shape[1], audio_emb.shape[2]), dtype=audio_emb.dtype, device=audio_emb.device) + zero_audio_embed_3 = torch.zeros((3, audio_emb.shape[1], audio_emb.shape[2]), dtype=audio_emb.dtype, device=audio_emb.device) # device=audio_emb.device + iter_ = 1 + (frame_num - 1) // 4 + audio_emb_wind = [] + for lt_i in range(iter_): + if lt_i == 0: + st = frame0_idx + lt_i - 2 + ed = frame0_idx + lt_i + 3 + wind_feat = torch.stack([ + audio_emb[i] if (0 <= i < audio_emb.shape[0]) else zero_audio_embed + for i in range(st, ed) + ], dim=0) + wind_feat = torch.cat((zero_audio_embed_3, wind_feat), dim=0) + else: + st = frame0_idx + 1 + 4 * (lt_i - 1) - audio_shift + ed = frame0_idx + 1 + 4 * lt_i + audio_shift + wind_feat = torch.stack([ + audio_emb[i] if (0 <= i < audio_emb.shape[0]) else zero_audio_embed + for i in range(st, ed) + ], dim=0) + audio_emb_wind.append(wind_feat) + audio_emb_wind = torch.stack(audio_emb_wind, dim=0) + + return audio_emb_wind, ed - audio_shift + + +class WanHuMoImageToVideo(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="WanHuMoImageToVideo", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, 
max=nodes.MAX_RESOLUTION, step=16),
+                io.Int.Input("length", default=97, min=1, max=nodes.MAX_RESOLUTION, step=4),
+                io.Int.Input("batch_size", default=1, min=1, max=4096),
+                io.AudioEncoderOutput.Input("audio_encoder_output", optional=True),
+                io.Image.Input("ref_image", optional=True),
+            ],
+            outputs=[
+                io.Conditioning.Output(display_name="positive"),
+                io.Conditioning.Output(display_name="negative"),
+                io.Latent.Output(display_name="latent"),
+            ],
+            is_experimental=True,
+        )
+
+    @classmethod
+    def execute(cls, positive, negative, vae, width, height, length, batch_size, ref_image=None, audio_encoder_output=None) -> io.NodeOutput:
+        latent_t = ((length - 1) // 4) + 1
+        latent = torch.zeros([batch_size, 16, latent_t, height // 8, width // 8], device=comfy.model_management.intermediate_device())
+
+        if ref_image is not None:
+            ref_image = comfy.utils.common_upscale(ref_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
+            ref_latent = vae.encode(ref_image[:, :, :, :3])
+            positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True)
+            negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [torch.zeros_like(ref_latent)]}, append=True)
+        else:
+            # no reference image: use an all-zero single-frame reference latent (keeps ref_latent defined for the audio padding below)
+            ref_latent = torch.zeros([batch_size, 16, 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
+            positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True)
+            negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [ref_latent]}, append=True)
+
+        if audio_encoder_output is not None:
+            audio_emb = torch.stack(audio_encoder_output["encoded_audio_all_layers"], dim=2)
+            audio_len = audio_encoder_output["audio_samples"] // 640
+            audio_emb = audio_emb[:, :audio_len * 2]
+
+            feat0 = linear_interpolation(audio_emb[:, :, 0: 8].mean(dim=2), 50, 25)
+            feat1 = linear_interpolation(audio_emb[:, :, 8: 16].mean(dim=2), 50, 25)
+            feat2 = linear_interpolation(audio_emb[:, :, 16: 24].mean(dim=2), 50, 25)
+            feat3 = linear_interpolation(audio_emb[:, :, 24: 32].mean(dim=2), 50, 25)
+            feat4 = linear_interpolation(audio_emb[:, :, 32], 50, 25)
+            audio_emb = torch.stack([feat0, feat1, feat2, feat3, feat4], dim=2)[0]  # [T, 5, 1280]
+            audio_emb, _ = get_audio_emb_window(audio_emb, length, frame0_idx=0)
+
+            # pad for ref latent
+            zero_audio_pad = torch.zeros(ref_latent.shape[2], *audio_emb.shape[1:], device=audio_emb.device, dtype=audio_emb.dtype)
+            audio_emb = torch.cat([audio_emb, zero_audio_pad], dim=0)
+
+            audio_emb = audio_emb.unsqueeze(0)
+            audio_emb_neg = torch.zeros_like(audio_emb)
+            positive = node_helpers.conditioning_set_values(positive, {"audio_embed": audio_emb})
+            negative = node_helpers.conditioning_set_values(negative, {"audio_embed": audio_emb_neg})
+        else:
+            zero_audio = torch.zeros([batch_size, latent_t + 1, 8, 5, 1280], device=comfy.model_management.intermediate_device())
+            positive = node_helpers.conditioning_set_values(positive, {"audio_embed": zero_audio})
+            negative = node_helpers.conditioning_set_values(negative, {"audio_embed": zero_audio})
+
+        out_latent = {}
+        out_latent["samples"] = latent
+        return io.NodeOutput(positive, negative, out_latent)
+
 class Wan22ImageToVideoLatent(io.ComfyNode):
     @classmethod
     def define_schema(cls):
@@ -1075,6 +1172,7 @@ async def get_node_list(self) -> list[type[io.ComfyNode]]:
             WanPhantomSubjectToVideo,
             WanSoundImageToVideo,
             WanSoundImageToVideoExtend,
+            WanHuMoImageToVideo,
             Wan22ImageToVideoLatent,
         ]
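
Note (not part of the patch above): the following is a minimal shape-check sketch of the audio windowing that `get_audio_emb_window` performs and the sizes `AudioProjModel` expects, run on dummy tensors. The names `dummy_audio_feat` and `latent_frames`, and the 25 Hz / 5 layer-group / 1280-channel assumptions, are illustrative only; the numbers follow from the defaults used by `WanHuMoImageToVideo` and `HumoWanModel` above.

```python
import torch

length = 97                              # requested video frames (node default)
latent_frames = ((length - 1) // 4) + 1  # temporal latent frames after 4x compression -> 25

# [T, blocks, channels] audio features after layer pooling and 50 -> 25 Hz interpolation
dummy_audio_feat = torch.randn(length, 5, 1280)

def zero_step():
    # zero padding for out-of-range audio steps
    return torch.zeros(5, 1280)

windows = []
for lt_i in range(latent_frames):
    if lt_i == 0:
        # first window: 5 real steps around frame 0, left-padded with 3 zero steps
        idx = range(-2, 3)
        feat = torch.stack([dummy_audio_feat[i] if 0 <= i < length else zero_step() for i in idx])
        feat = torch.cat([torch.zeros(3, 5, 1280), feat])
    else:
        # later windows: the 4 video frames behind this latent frame plus 2 steps of context on each side
        idx = range(1 + 4 * (lt_i - 1) - 2, 1 + 4 * lt_i + 2)
        feat = torch.stack([dummy_audio_feat[i] if 0 <= i < length else zero_step() for i in idx])
    windows.append(feat)

# (the node additionally appends one zero window per reference-latent frame before batching)
audio_emb = torch.stack(windows).unsqueeze(0)   # [1, latent_frames, 8, 5, 1280]
print(audio_emb.shape)                          # torch.Size([1, 25, 8, 5, 1280])

# AudioProjModel flattens each 8x5x1280 window (51200 features) and projects it
# to 16 tokens of width 1536, one group of 16 audio tokens per latent frame.
print(8 * 5 * 1280, 16 * 1536)                  # 51200 24576
```

Each latent frame therefore gets its own group of 16 projected audio tokens, and `WanT2VCrossAttentionGather` attends to that per-frame group rather than to the whole audio sequence.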