From c9d7f551591cf6c90bdbf261a911fd2e224a2420 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 19 Sep 2025 18:44:08 -0400 Subject: [PATCH] Update WanAnimateToVideo to more easily extend videos. --- comfy/ldm/wan/model_animate.py | 2 +- comfy_extras/nodes_wan.py | 59 +++++++++++++++++++++++++--------- 2 files changed, 45 insertions(+), 16 deletions(-) diff --git a/comfy/ldm/wan/model_animate.py b/comfy/ldm/wan/model_animate.py index 542f5411013d..7c87835d467f 100644 --- a/comfy/ldm/wan/model_animate.py +++ b/comfy/ldm/wan/model_animate.py @@ -451,7 +451,7 @@ def __init__(self, def after_patch_embedding(self, x, pose_latents, face_pixel_values): if pose_latents is not None: pose_latents = self.pose_patch_embedding(pose_latents) - x[:, :, 1:] += pose_latents + x[:, :, 1:pose_latents.shape[2] + 1] += pose_latents[:, :, :x.shape[2] - 1] if face_pixel_values is None: return x, None diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index 4187a5619f38..3e5fef535a33 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -1128,18 +1128,22 @@ def define_schema(cls): io.Image.Input("pose_video", optional=True), io.Int.Input("continue_motion_max_frames", default=5, min=1, max=nodes.MAX_RESOLUTION, step=4), io.Image.Input("continue_motion", optional=True), + io.Int.Input("video_frame_offset", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1, tooltip="The amount of frames to seek in all the input videos. Used for generating longer videos by chunk. Connect to the video_frame_offset output of the previous node for extending a video."), ], outputs=[ io.Conditioning.Output(display_name="positive"), io.Conditioning.Output(display_name="negative"), io.Latent.Output(display_name="latent"), io.Int.Output(display_name="trim_latent"), + io.Int.Output(display_name="trim_image"), + io.Int.Output(display_name="video_frame_offset"), ], is_experimental=True, ) @classmethod - def execute(cls, positive, negative, vae, width, height, length, batch_size, continue_motion_max_frames, reference_image=None, clip_vision_output=None, face_video=None, pose_video=None, continue_motion=None) -> io.NodeOutput: + def execute(cls, positive, negative, vae, width, height, length, batch_size, continue_motion_max_frames, video_frame_offset, reference_image=None, clip_vision_output=None, face_video=None, pose_video=None, continue_motion=None) -> io.NodeOutput: + trim_to_pose_video = False latent_length = ((length - 1) // 4) + 1 latent_width = width // 8 latent_height = height // 8 @@ -1152,35 +1156,60 @@ def execute(cls, positive, negative, vae, width, height, length, batch_size, con concat_latent_image = vae.encode(image[:, :, :, :3]) mask = torch.zeros((1, 1, concat_latent_image.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=concat_latent_image.device, dtype=concat_latent_image.dtype) trim_latent += concat_latent_image.shape[2] + ref_motion_latent_length = 0 + + if continue_motion is None: + image = torch.ones((length, height, width, 3)) * 0.5 + else: + continue_motion = continue_motion[-continue_motion_max_frames:] + video_frame_offset -= continue_motion.shape[0] + video_frame_offset = max(0, video_frame_offset) + continue_motion = comfy.utils.common_upscale(continue_motion[-length:].movedim(-1, 1), width, height, "area", "center").movedim(1, -1) + image = torch.ones((length, height, width, continue_motion.shape[-1]), device=continue_motion.device, dtype=continue_motion.dtype) * 0.5 + image[:continue_motion.shape[0]] = continue_motion + ref_motion_latent_length += ((continue_motion.shape[0] - 1) // 4) + 1 if clip_vision_output is not None: positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output}) negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output}) - if face_video is not None: - face_video = comfy.utils.common_upscale(face_video[:length].movedim(-1, 1), 512, 512, "area", "center") * 2.0 - 1.0 - face_video = face_video.movedim(0, 1).unsqueeze(0) - positive = node_helpers.conditioning_set_values(positive, {"face_video_pixels": face_video}) - negative = node_helpers.conditioning_set_values(negative, {"face_video_pixels": face_video * 0.0 - 1.0}) + if pose_video is not None: + if pose_video.shape[0] <= video_frame_offset: + pose_video = None + else: + pose_video = pose_video[video_frame_offset:] if pose_video is not None: pose_video = comfy.utils.common_upscale(pose_video[:length].movedim(-1, 1), width, height, "area", "center").movedim(1, -1) + if not trim_to_pose_video: + if pose_video.shape[0] < length: + pose_video = torch.cat((pose_video,) + (pose_video[-1:],) * (length - pose_video.shape[0]), dim=0) + pose_video_latent = vae.encode(pose_video[:, :, :, :3]) positive = node_helpers.conditioning_set_values(positive, {"pose_video_latent": pose_video_latent}) negative = node_helpers.conditioning_set_values(negative, {"pose_video_latent": pose_video_latent}) - if continue_motion is None: - image = torch.ones((length, height, width, 3)) * 0.5 - else: - continue_motion = continue_motion[-continue_motion_max_frames:] - continue_motion = comfy.utils.common_upscale(continue_motion[-length:].movedim(-1, 1), width, height, "area", "center").movedim(1, -1) - image = torch.ones((length, height, width, continue_motion.shape[-1]), device=continue_motion.device, dtype=continue_motion.dtype) * 0.5 - image[:continue_motion.shape[0]] = continue_motion + if trim_to_pose_video: + latent_length = pose_video_latent.shape[2] + length = latent_length * 4 - 3 + image = image[:length] + + if face_video is not None: + if face_video.shape[0] <= video_frame_offset: + face_video = None + else: + face_video = face_video[video_frame_offset:] + + if face_video is not None: + face_video = comfy.utils.common_upscale(face_video[:length].movedim(-1, 1), 512, 512, "area", "center") * 2.0 - 1.0 + face_video = face_video.movedim(0, 1).unsqueeze(0) + positive = node_helpers.conditioning_set_values(positive, {"face_video_pixels": face_video}) + negative = node_helpers.conditioning_set_values(negative, {"face_video_pixels": face_video * 0.0 - 1.0}) concat_latent_image = torch.cat((concat_latent_image, vae.encode(image[:, :, :, :3])), dim=2) mask_refmotion = torch.ones((1, 1, latent_length, concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=mask.device, dtype=mask.dtype) if continue_motion is not None: - mask_refmotion[:, :, :((continue_motion.shape[0] - 1) // 4) + 1] = 0.0 + mask_refmotion[:, :, :ref_motion_latent_length] = 0.0 mask = torch.cat((mask, mask_refmotion), dim=2) positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask}) @@ -1189,7 +1218,7 @@ def execute(cls, positive, negative, vae, width, height, length, batch_size, con latent = torch.zeros([batch_size, 16, latent_length + trim_latent, latent_height, latent_width], device=comfy.model_management.intermediate_device()) out_latent = {} out_latent["samples"] = latent - return io.NodeOutput(positive, negative, out_latent, trim_latent) + return io.NodeOutput(positive, negative, out_latent, trim_latent, max(0, ref_motion_latent_length * 4 - 3), video_frame_offset + length) class Wan22ImageToVideoLatent(io.ComfyNode): @classmethod