diff --git a/comfy_extras/nodes_hunyuan.py b/comfy_extras/nodes_hunyuan.py
index db398cdf14a6..4d4d6b8d20b1 100644
--- a/comfy_extras/nodes_hunyuan.py
+++ b/comfy_extras/nodes_hunyuan.py
@@ -1,12 +1,15 @@
+import math
 import nodes
 import node_helpers
 import torch
+import re
 import comfy.model_management
+import comfy.patcher_extension


 class CLIPTextEncodeHunyuanDiT:
     @classmethod
-    def INPUT_TYPES(s):
+    def INPUT_TYPES(cls):
         return {"required": {
             "clip": ("CLIP", ),
             "bert": ("STRING", {"multiline": True, "dynamicPrompts": True}),
@@ -23,6 +26,216 @@ def encode(self, clip, bert, mt5xl):

         return (clip.encode_from_tokens_scheduled(tokens), )

+class MomentumBuffer:
+    def __init__(self, momentum: float):
+        self.momentum = momentum
+        self.running_average = 0
+
+    def update(self, update_value: torch.Tensor):
+        new_average = self.momentum * self.running_average
+        self.running_average = update_value + new_average
+
+def normalized_guidance_apg(
+    pred_cond: torch.Tensor,
+    pred_uncond: torch.Tensor,
+    guidance_scale: float,
+    momentum_buffer,
+    eta: float = 1.0,
+    norm_threshold: float = 0.0,
+    use_original_formulation: bool = False,
+):
+    diff = pred_cond - pred_uncond
+    dim = [-i for i in range(1, len(diff.shape))]
+
+    if momentum_buffer is not None:
+        momentum_buffer.update(diff)
+        diff = momentum_buffer.running_average
+
+    if norm_threshold > 0:
+        ones = torch.ones_like(diff)
+        diff_norm = diff.norm(p=2, dim=dim, keepdim=True)
+        scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
+        diff = diff * scale_factor
+
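+    # Decompose the (momentum-smoothed, norm-clamped) guidance delta into components
+    # parallel and orthogonal to the conditional prediction; eta rescales only the
+    # parallel component before the update is applied.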
+    v0, v1 = diff.double(), pred_cond.double()
+    v1 = torch.nn.functional.normalize(v1, dim=dim)
+    v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
+    v0_orthogonal = v0 - v0_parallel
+    diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff)
+
+    normalized_update = diff_orthogonal + eta * diff_parallel
+    pred = pred_cond if use_original_formulation else pred_uncond
+    pred = pred + guidance_scale * normalized_update
+
+    return pred
+
+class AdaptiveProjectedGuidance:
+    def __init__(
+        self,
+        guidance_scale: float = 7.5,
+        adaptive_projected_guidance_momentum=None,
+        adaptive_projected_guidance_rescale: float = 15.0,
+        # eta: float = 1.0,
+        eta: float = 0.0,
+        guidance_rescale: float = 0.0,
+        use_original_formulation: bool = False,
+        start: float = 0.0,
+        stop: float = 1.0,
+    ):
+        super().__init__()
+
+        self.guidance_scale = guidance_scale
+        self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
+        self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale
+        self.eta = eta
+        self.guidance_rescale = guidance_rescale
+        self.use_original_formulation = use_original_formulation
+        self.momentum_buffer = None
+
+    def __call__(self, pred_cond: torch.Tensor, pred_uncond=None, is_first_step=False) -> torch.Tensor:
+
+        if is_first_step and self.adaptive_projected_guidance_momentum is not None:
+            self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
+
+        pred = normalized_guidance_apg(
+            pred_cond,
+            pred_uncond,
+            self.guidance_scale,
+            self.momentum_buffer,
+            self.eta,
+            self.adaptive_projected_guidance_rescale,
+            self.use_original_formulation,
+        )
+
+        return pred
+
+class HunyuanMixModeAPG:
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "model": ("MODEL", ),
+                "has_quoted_text": ("BOOLEAN", ),
+
+                "guidance_scale": ("FLOAT", {"default": 10.0, "min": 1.0, "max": 30.0, "step": 0.1}),
+
+                "general_eta": ("FLOAT", {"default": 0.0, "min": -10.0, "max": 10.0, "step": 0.01}),
+                "general_norm_threshold": ("FLOAT", {"default": 10.0, "min": 0.0, "max": 50.0, "step": 0.1}),
+                "general_momentum": ("FLOAT", {"default": -0.5, "min": -5.0, "max": 1.0, "step": 0.01}),
+                "general_start_percent": ("FLOAT", {"default": 0.10, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "The relative sampling step to begin use of general APG."}),
+
+                "ocr_eta": ("FLOAT", {"default": 0.0, "min": -10.0, "max": 10.0, "step": 0.01}),
+                "ocr_norm_threshold": ("FLOAT", {"default": 10.0, "min": 0.0, "max": 50.0, "step": 0.1}),
+                "ocr_momentum": ("FLOAT", {"default": -0.5, "min": -5.0, "max": 1.0, "step": 0.01}),
+                "ocr_start_percent": ("FLOAT", {"default": 0.75, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "The relative sampling step to begin use of OCR APG."}),
+
+            }
+        }
+
+    RETURN_TYPES = ("MODEL",)
+    FUNCTION = "apply_mix_mode_apg"
+    CATEGORY = "sampling/custom_sampling/hunyuan"
+
+
+    def apply_mix_mode_apg(self, model, has_quoted_text, guidance_scale, general_eta, general_norm_threshold, general_momentum, general_start_percent,
+                           ocr_eta, ocr_norm_threshold, ocr_momentum, ocr_start_percent):
+
+        general_apg = AdaptiveProjectedGuidance(
+            guidance_scale=guidance_scale,
+            eta=general_eta,
+            adaptive_projected_guidance_rescale=general_norm_threshold,
+            adaptive_projected_guidance_momentum=general_momentum
+        )
+
+        ocr_apg = AdaptiveProjectedGuidance(
+            eta=ocr_eta,
+            adaptive_projected_guidance_rescale=ocr_norm_threshold,
+            adaptive_projected_guidance_momentum=ocr_momentum
+        )
+
+        m = model.clone()
+
+
+        model_sampling = m.model.model_sampling
+        general_start_t = model_sampling.percent_to_sigma(general_start_percent)
+        ocr_start_t = model_sampling.percent_to_sigma(ocr_start_percent)
+
+
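+        # Prompts without quoted text use the "general" APG parameters; prompts with
+        # quoted text use the "ocr" parameters. APG only replaces CFG once sigma falls
+        # below the configured start sigma; before that standard CFG is returned
+        # (still stepping the momentum buffer when cond_scale > 1).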
+        def cfg_function(args):
+            sigma = args["sigma"].to(torch.float32)
+            is_first_step = math.isclose(sigma.item(), args['model_options']['transformer_options']['sample_sigmas'][0].item())
+            cond = args["cond"]
+            uncond = args["uncond"]
+            cond_scale = args["cond_scale"]
+
+            sigma = sigma[:, None, None, None]
+
+
+            if not has_quoted_text:
+                if sigma[0] <= general_start_t:
+                    modified_cond = general_apg(cond / sigma, uncond / sigma, is_first_step=is_first_step)
+                    return modified_cond * sigma
+                else:
+                    if cond_scale > 1:
+                        _ = general_apg(cond / sigma, uncond / sigma, is_first_step=is_first_step)  # track momentum
+                    return uncond + (cond - uncond) * cond_scale
+            else:
+                if sigma[0] <= ocr_start_t:
+                    modified_cond = ocr_apg(cond / sigma, uncond / sigma, is_first_step=is_first_step)
+                    return modified_cond * sigma
+                else:
+                    if cond_scale > 1:
+                        _ = ocr_apg(cond / sigma, uncond / sigma, is_first_step=is_first_step)  # track momentum
+                    return uncond + (cond - uncond) * cond_scale
+
+            return cond
+
+        m.set_model_sampler_cfg_function(cfg_function, disable_cfg1_optimization=True)
+        return (m,)
+
+class CLIPTextEncodeHunyuanDiTWithTextDetection:
+
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {"required": {
+            "clip": ("CLIP", ),
+            "text": ("STRING", {"multiline": True, "dynamicPrompts": True}),
+        }}
+
+    RETURN_TYPES = ("CONDITIONING", "BOOLEAN")
+    RETURN_NAMES = ("conditioning", "has_quoted_text")
+    FUNCTION = "encode"
+
+    CATEGORY = "advanced/conditioning/hunyuan"
+
+    def detect_quoted_text(self, text):
+        """Detect quoted text in the prompt"""
+        text_prompt_texts = []
+
+        # Patterns to match different quote styles
+        pattern_quote_double = r'\"(.*?)\"'
+        pattern_quote_chinese_single = r'‘(.*?)’'
+        pattern_quote_chinese_double = r'“(.*?)”'
+
+        matches_quote_double = re.findall(pattern_quote_double, text)
+        matches_quote_chinese_single = re.findall(pattern_quote_chinese_single, text)
+        matches_quote_chinese_double = re.findall(pattern_quote_chinese_double, text)
+
+        text_prompt_texts.extend(matches_quote_double)
+        text_prompt_texts.extend(matches_quote_chinese_single)
+        text_prompt_texts.extend(matches_quote_chinese_double)
+
+        return len(text_prompt_texts) > 0
+
+    def encode(self, clip, text):
+        tokens = clip.tokenize(text)
+        has_quoted_text = self.detect_quoted_text(text)
+
+        conditioning = clip.encode_from_tokens_scheduled(tokens)
+
+        return (conditioning, has_quoted_text)
+
 class EmptyHunyuanLatentVideo:
     @classmethod
     def INPUT_TYPES(s):
@@ -151,8 +364,16 @@ def execute(self, positive, negative, latent, noise_augmentation):

         return (positive, negative, out_latent)

+
+NODE_DISPLAY_NAME_MAPPINGS = {
+    "HunyuanMixModeAPG": "Hunyuan Mix Mode APG",
+    "HunyuanStepBasedAPG": "Hunyuan Step Based APG",
+}
+
 NODE_CLASS_MAPPINGS = {
+    "HunyuanMixModeAPG": HunyuanMixModeAPG,
     "CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT,
+    "CLIPTextEncodeHunyuanDiTWithTextDetection": CLIPTextEncodeHunyuanDiTWithTextDetection,
     "TextEncodeHunyuanVideo_ImageToVideo": TextEncodeHunyuanVideo_ImageToVideo,
     "EmptyHunyuanLatentVideo": EmptyHunyuanLatentVideo,
     "HunyuanImageToVideo": HunyuanImageToVideo,