From f3f87e3ba678ec1a9ad555212883bfecf0548508 Mon Sep 17 00:00:00 2001 From: KimbingNg Date: Mon, 15 Sep 2025 22:27:43 +0800 Subject: [PATCH 1/6] HunyuanImage2.1: Implement Hunyuan APG --- comfy_extras/nodes_hunyuan.py | 226 +++++++++++++++++++++++++++++++++- 1 file changed, 225 insertions(+), 1 deletion(-) diff --git a/comfy_extras/nodes_hunyuan.py b/comfy_extras/nodes_hunyuan.py index db398cdf14a6..eb3ff0287be8 100644 --- a/comfy_extras/nodes_hunyuan.py +++ b/comfy_extras/nodes_hunyuan.py @@ -1,12 +1,14 @@ +from numpy import arccos import nodes import node_helpers import torch +import re import comfy.model_management class CLIPTextEncodeHunyuanDiT: @classmethod - def INPUT_TYPES(s): + def INPUT_TYPES(cls): return {"required": { "clip": ("CLIP", ), "bert": ("STRING", {"multiline": True, "dynamicPrompts": True}), @@ -23,6 +25,220 @@ def encode(self, clip, bert, mt5xl): return (clip.encode_from_tokens_scheduled(tokens), ) +class MomentumBuffer: + def __init__(self, momentum: float): + self.momentum = momentum + self.running_average = 0 + + def update(self, update_value: torch.Tensor): + new_average = self.momentum * self.running_average + self.running_average = update_value + new_average + +def normalized_guidance_apg( + pred_cond: torch.Tensor, + pred_uncond: torch.Tensor, + guidance_scale: float, + momentum_buffer, + eta: float = 1.0, + norm_threshold: float = 0.0, + use_original_formulation: bool = False, +): + diff = pred_cond - pred_uncond + dim = [-i for i in range(1, len(diff.shape))] + + if momentum_buffer is not None: + momentum_buffer.update(diff) + diff = momentum_buffer.running_average + + if norm_threshold > 0: + ones = torch.ones_like(diff) + diff_norm = diff.norm(p=2, dim=dim, keepdim=True) + scale_factor = torch.minimum(ones, norm_threshold / diff_norm) + diff = diff * scale_factor + + v0, v1 = diff.double(), pred_cond.double() + v1 = torch.nn.functional.normalize(v1, dim=dim) + v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1 + v0_orthogonal = v0 - v0_parallel + diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff) + + normalized_update = diff_orthogonal + eta * diff_parallel + pred = pred_cond if use_original_formulation else pred_uncond + pred = pred + guidance_scale * normalized_update + + return pred + +class AdaptiveProjectedGuidance: + def __init__( + self, + guidance_scale: float = 7.5, + adaptive_projected_guidance_momentum=None, + adaptive_projected_guidance_rescale: float = 15.0, + # eta: float = 1.0, + eta: float = 0.0, + guidance_rescale: float = 0.0, + use_original_formulation: bool = False, + start: float = 0.0, + stop: float = 1.0, + ): + super().__init__() + + self.guidance_scale = guidance_scale + self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum + self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale + self.eta = eta + self.guidance_rescale = guidance_rescale + self.use_original_formulation = use_original_formulation + self.momentum_buffer = None + + def __call__(self, pred_cond: torch.Tensor, pred_uncond=None, step=None) -> torch.Tensor: + + if step == 0 and self.adaptive_projected_guidance_momentum is not None: + self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum) + + pred = normalized_guidance_apg( + pred_cond, + pred_uncond, + self.guidance_scale, + self.momentum_buffer, + self.eta, + self.adaptive_projected_guidance_rescale, + self.use_original_formulation, + ) + + return pred + +class HunyuanMixModeAPG: + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model": ("MODEL", ), + "has_quoted_text": ("HAS_QUOTED_TEXT", ), + + "guidance_scale": ("FLOAT", {"default": 9.0, "min": 1.0, "max": 30.0, "step": 0.1}), + + "general_eta": ("FLOAT", {"default": 0.0, "min": -10.0, "max": 10.0, "step": 0.01}), + "general_norm_threshold": ("FLOAT", {"default": 10.0, "min": 0.0, "max": 50.0, "step": 0.1}), + "general_momentum": ("FLOAT", {"default": -0.5, "min": -5.0, "max": 1.0, "step": 0.01}), + "general_start_step": ("INT", {"default": 10, "min": -1, "max": 1000}), + + "ocr_eta": ("FLOAT", {"default": 0.0, "min": -10.0, "max": 10.0, "step": 0.01}), + "ocr_norm_threshold": ("FLOAT", {"default": 10.0, "min": 0.0, "max": 50.0, "step": 0.1}), + "ocr_momentum": ("FLOAT", {"default": -0.5, "min": -5.0, "max": 1.0, "step": 0.01}), + "ocr_start_step": ("INT", {"default": 75, "min": -1, "max": 1000}), + + } + } + + RETURN_TYPES = ("MODEL",) + FUNCTION = "apply_mix_mode_apg" + CATEGORY = "sampling/custom_sampling/hunyuan" + + + @classmethod + def IS_CHANGED(cls, model): + return True + + def apply_mix_mode_apg(self, model, has_quoted_text, guidance_scale, general_eta, general_norm_threshold, general_momentum, general_start_step, + ocr_eta, ocr_norm_threshold, ocr_momentum, ocr_start_step): + + general_apg = AdaptiveProjectedGuidance( + guidance_scale=guidance_scale, + eta=general_eta, + adaptive_projected_guidance_rescale=general_norm_threshold, + adaptive_projected_guidance_momentum=general_momentum + ) + + ocr_apg = AdaptiveProjectedGuidance( + eta=ocr_eta, + adaptive_projected_guidance_rescale=ocr_norm_threshold, + adaptive_projected_guidance_momentum=ocr_momentum + ) + + current_step = {"step": 0} + + def cfg_function(args): + cond = args["cond"] + uncond = args["uncond"] + cond_scale = args["cond_scale"] + + step = current_step["step"] + current_step["step"] += 1 + + if not has_quoted_text: + if step > general_start_step: + modified_cond = general_apg(cond, uncond, step).to(torch.bfloat16) + return modified_cond + else: + if cond_scale > 1: + _ = general_apg(cond, uncond, step) # track momentum + return uncond + (cond - uncond) * cond_scale + else: + if step > ocr_start_step: + modified_cond = ocr_apg(cond, uncond, step) + return modified_cond + else: + if cond_scale > 1: + _ = ocr_apg(cond, uncond, step) + return uncond + (cond - uncond) * cond_scale + + return cond + + + m = model.clone() + m.set_model_sampler_cfg_function(cfg_function, disable_cfg1_optimization=True) + return (m,) + +class CLIPTextEncodeHunyuanDiTWithTextDetection: + + @classmethod + def INPUT_TYPES(cls): + return {"required": { + "clip": ("CLIP", ), + "text": ("STRING", {"multiline": True, "dynamicPrompts": True}), + }} + + RETURN_TYPES = ("CONDITIONING", "HAS_QUOTED_TEXT") + RETURN_NAMES = ("conditioning", "has_quoted_text") + FUNCTION = "encode" + + CATEGORY = "advanced/conditioning/hunyuan" + + def detect_quoted_text(self, text): + """Detect quoted text in the prompt""" + text_prompt_texts = [] + + # Patterns to match different quote styles + pattern_quote_double = r'\"(.*?)\"' + pattern_quote_chinese_single = r'‘(.*?)’' + pattern_quote_chinese_double = r'“(.*?)”' + + matches_quote_double = re.findall(pattern_quote_double, text) + matches_quote_chinese_single = re.findall(pattern_quote_chinese_single, text) + matches_quote_chinese_double = re.findall(pattern_quote_chinese_double, text) + + text_prompt_texts.extend(matches_quote_double) + text_prompt_texts.extend(matches_quote_chinese_single) + text_prompt_texts.extend(matches_quote_chinese_double) + + return len(text_prompt_texts) > 0 + + def encode(self, clip, text): + tokens = clip.tokenize(text) + has_quoted_text = self.detect_quoted_text(text) + + conditioning = clip.encode_from_tokens_scheduled(tokens) + + c = [] + for t in conditioning: + n = [t[0], t[1].copy()] + n[1]['has_quoted_text'] = has_quoted_text + c.append(n) + + return (c, has_quoted_text) + class EmptyHunyuanLatentVideo: @classmethod def INPUT_TYPES(s): @@ -151,8 +367,16 @@ def execute(self, positive, negative, latent, noise_augmentation): return (positive, negative, out_latent) + +NODE_DISPLAY_NAME_MAPPINGS = { + "HunyuanMixModeAPG": "Hunyuan Mix Mode APG", + "HunyuanStepBasedAPG": "Hunyuan Step Based APG", +} + NODE_CLASS_MAPPINGS = { + "HunyuanMixModeAPG": HunyuanMixModeAPG, "CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT, + "CLIPTextEncodeHunyuanDiTWithTextDetection": CLIPTextEncodeHunyuanDiTWithTextDetection, "TextEncodeHunyuanVideo_ImageToVideo": TextEncodeHunyuanVideo_ImageToVideo, "EmptyHunyuanLatentVideo": EmptyHunyuanLatentVideo, "HunyuanImageToVideo": HunyuanImageToVideo, From 84f1acd1030df05c040b66098970aa63da3857eb Mon Sep 17 00:00:00 2001 From: KimbingNg Date: Tue, 16 Sep 2025 15:46:12 +0800 Subject: [PATCH 2/6] fix apg sigma, change default settings --- comfy_extras/nodes_hunyuan.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/comfy_extras/nodes_hunyuan.py b/comfy_extras/nodes_hunyuan.py index eb3ff0287be8..a871fb4870b4 100644 --- a/comfy_extras/nodes_hunyuan.py +++ b/comfy_extras/nodes_hunyuan.py @@ -117,17 +117,17 @@ def INPUT_TYPES(s): "model": ("MODEL", ), "has_quoted_text": ("HAS_QUOTED_TEXT", ), - "guidance_scale": ("FLOAT", {"default": 9.0, "min": 1.0, "max": 30.0, "step": 0.1}), + "guidance_scale": ("FLOAT", {"default": 10.0, "min": 1.0, "max": 30.0, "step": 0.1}), "general_eta": ("FLOAT", {"default": 0.0, "min": -10.0, "max": 10.0, "step": 0.01}), "general_norm_threshold": ("FLOAT", {"default": 10.0, "min": 0.0, "max": 50.0, "step": 0.1}), "general_momentum": ("FLOAT", {"default": -0.5, "min": -5.0, "max": 1.0, "step": 0.01}), - "general_start_step": ("INT", {"default": 10, "min": -1, "max": 1000}), + "general_start_step": ("INT", {"default": 5, "min": -1, "max": 1000}), "ocr_eta": ("FLOAT", {"default": 0.0, "min": -10.0, "max": 10.0, "step": 0.01}), "ocr_norm_threshold": ("FLOAT", {"default": 10.0, "min": 0.0, "max": 50.0, "step": 0.1}), "ocr_momentum": ("FLOAT", {"default": -0.5, "min": -5.0, "max": 1.0, "step": 0.01}), - "ocr_start_step": ("INT", {"default": 75, "min": -1, "max": 1000}), + "ocr_start_step": ("INT", {"default": 38, "min": -1, "max": 1000}), } } @@ -160,6 +160,7 @@ def apply_mix_mode_apg(self, model, has_quoted_text, guidance_scale, general_et current_step = {"step": 0} def cfg_function(args): + sigma = args["sigma"].to(torch.float32) cond = args["cond"] uncond = args["uncond"] cond_scale = args["cond_scale"] @@ -168,20 +169,20 @@ def cfg_function(args): current_step["step"] += 1 if not has_quoted_text: - if step > general_start_step: - modified_cond = general_apg(cond, uncond, step).to(torch.bfloat16) - return modified_cond + if step >= general_start_step: + modified_cond = general_apg(cond / sigma, uncond / sigma, step) + return modified_cond * sigma else: if cond_scale > 1: - _ = general_apg(cond, uncond, step) # track momentum + _ = general_apg(cond / sigma, uncond / sigma, step) # track momentum return uncond + (cond - uncond) * cond_scale else: - if step > ocr_start_step: - modified_cond = ocr_apg(cond, uncond, step) - return modified_cond + if step >= ocr_start_step: + modified_cond = ocr_apg(cond / sigma, uncond / sigma, step) + return modified_cond * sigma else: if cond_scale > 1: - _ = ocr_apg(cond, uncond, step) + _ = ocr_apg(cond / sigma, uncond / sigma, step) # track momentum return uncond + (cond - uncond) * cond_scale return cond From 9492b98b41bede997533f90df34cbe2860b640aa Mon Sep 17 00:00:00 2001 From: KimbingNg Date: Tue, 16 Sep 2025 16:19:47 +0800 Subject: [PATCH 3/6] Add OUTER_SAMPLE wrapper --- comfy_extras/nodes_hunyuan.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/comfy_extras/nodes_hunyuan.py b/comfy_extras/nodes_hunyuan.py index a871fb4870b4..6e66c85c5036 100644 --- a/comfy_extras/nodes_hunyuan.py +++ b/comfy_extras/nodes_hunyuan.py @@ -1,9 +1,9 @@ -from numpy import arccos import nodes import node_helpers import torch import re import comfy.model_management +import comfy.patcher_extension class CLIPTextEncodeHunyuanDiT: @@ -137,10 +137,6 @@ def INPUT_TYPES(s): CATEGORY = "sampling/custom_sampling/hunyuan" - @classmethod - def IS_CHANGED(cls, model): - return True - def apply_mix_mode_apg(self, model, has_quoted_text, guidance_scale, general_eta, general_norm_threshold, general_momentum, general_start_step, ocr_eta, ocr_norm_threshold, ocr_momentum, ocr_start_step): @@ -157,7 +153,13 @@ def apply_mix_mode_apg(self, model, has_quoted_text, guidance_scale, general_et adaptive_projected_guidance_momentum=ocr_momentum ) - current_step = {"step": 0} + m = model.clone() + step_tracker = {"step": 0} + + def hunyuan_apg_outer_sample_wrapper(executor, *args, **kwargs): + step_tracker['step'] = 0 + return executor(*args, **kwargs) + def cfg_function(args): sigma = args["sigma"].to(torch.float32) @@ -165,8 +167,8 @@ def cfg_function(args): uncond = args["uncond"] cond_scale = args["cond_scale"] - step = current_step["step"] - current_step["step"] += 1 + step = step_tracker['step'] + step_tracker['step'] += 1 if not has_quoted_text: if step >= general_start_step: @@ -187,8 +189,7 @@ def cfg_function(args): return cond - - m = model.clone() + m.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "hunyuan_apg", hunyuan_apg_outer_sample_wrapper) m.set_model_sampler_cfg_function(cfg_function, disable_cfg1_optimization=True) return (m,) From 6e6065b79391165d9445b34176c8c08cd5b3c1ea Mon Sep 17 00:00:00 2001 From: KimbingNg Date: Mon, 22 Sep 2025 12:25:43 +0800 Subject: [PATCH 4/6] Percentage APG & bug fixed. percentage step to switch apg/cfg. Fix shape mismatch error for batchsize != 1 --- comfy_extras/nodes_hunyuan.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/comfy_extras/nodes_hunyuan.py b/comfy_extras/nodes_hunyuan.py index 6e66c85c5036..8f910740b085 100644 --- a/comfy_extras/nodes_hunyuan.py +++ b/comfy_extras/nodes_hunyuan.py @@ -122,12 +122,12 @@ def INPUT_TYPES(s): "general_eta": ("FLOAT", {"default": 0.0, "min": -10.0, "max": 10.0, "step": 0.01}), "general_norm_threshold": ("FLOAT", {"default": 10.0, "min": 0.0, "max": 50.0, "step": 0.1}), "general_momentum": ("FLOAT", {"default": -0.5, "min": -5.0, "max": 1.0, "step": 0.01}), - "general_start_step": ("INT", {"default": 5, "min": -1, "max": 1000}), + "general_start_percent": ("FLOAT", {"default": 0.10, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "The relative sampling step to begin use of general APG."}), "ocr_eta": ("FLOAT", {"default": 0.0, "min": -10.0, "max": 10.0, "step": 0.01}), "ocr_norm_threshold": ("FLOAT", {"default": 10.0, "min": 0.0, "max": 50.0, "step": 0.1}), "ocr_momentum": ("FLOAT", {"default": -0.5, "min": -5.0, "max": 1.0, "step": 0.01}), - "ocr_start_step": ("INT", {"default": 38, "min": -1, "max": 1000}), + "ocr_start_percent": ("FLOAT", {"default": 0.75, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "The relative sampling step to begin use of OCR APG."}), } } @@ -137,8 +137,8 @@ def INPUT_TYPES(s): CATEGORY = "sampling/custom_sampling/hunyuan" - def apply_mix_mode_apg(self, model, has_quoted_text, guidance_scale, general_eta, general_norm_threshold, general_momentum, general_start_step, - ocr_eta, ocr_norm_threshold, ocr_momentum, ocr_start_step): + def apply_mix_mode_apg(self, model, has_quoted_text, guidance_scale, general_eta, general_norm_threshold, general_momentum, general_start_percent, + ocr_eta, ocr_norm_threshold, ocr_momentum, ocr_start_percent): general_apg = AdaptiveProjectedGuidance( guidance_scale=guidance_scale, @@ -154,15 +154,21 @@ def apply_mix_mode_apg(self, model, has_quoted_text, guidance_scale, general_et ) m = model.clone() + + + model_sampling = m.model.model_sampling + general_start_t = model_sampling.percent_to_sigma(general_start_percent) + ocr_start_t = model_sampling.percent_to_sigma(ocr_start_percent) + step_tracker = {"step": 0} def hunyuan_apg_outer_sample_wrapper(executor, *args, **kwargs): step_tracker['step'] = 0 return executor(*args, **kwargs) - def cfg_function(args): sigma = args["sigma"].to(torch.float32) + sigma = sigma[:, None, None, None] cond = args["cond"] uncond = args["uncond"] cond_scale = args["cond_scale"] @@ -171,7 +177,7 @@ def cfg_function(args): step_tracker['step'] += 1 if not has_quoted_text: - if step >= general_start_step: + if sigma[0] <= general_start_t: modified_cond = general_apg(cond / sigma, uncond / sigma, step) return modified_cond * sigma else: @@ -179,7 +185,7 @@ def cfg_function(args): _ = general_apg(cond / sigma, uncond / sigma, step) # track momentum return uncond + (cond - uncond) * cond_scale else: - if step >= ocr_start_step: + if sigma[0] <= ocr_start_t: modified_cond = ocr_apg(cond / sigma, uncond / sigma, step) return modified_cond * sigma else: From f23bac81082fa9d4730fa1ee58b33d4c2d97ac1b Mon Sep 17 00:00:00 2001 From: KimbingNg Date: Tue, 23 Sep 2025 17:16:59 +0800 Subject: [PATCH 5/6] has_quoted_text -> boolean. remove unused dict entry --- comfy_extras/nodes_hunyuan.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/comfy_extras/nodes_hunyuan.py b/comfy_extras/nodes_hunyuan.py index 8f910740b085..fb1686188a9b 100644 --- a/comfy_extras/nodes_hunyuan.py +++ b/comfy_extras/nodes_hunyuan.py @@ -115,7 +115,7 @@ def INPUT_TYPES(s): return { "required": { "model": ("MODEL", ), - "has_quoted_text": ("HAS_QUOTED_TEXT", ), + "has_quoted_text": ("BOOLEAN", ), "guidance_scale": ("FLOAT", {"default": 10.0, "min": 1.0, "max": 30.0, "step": 0.1}), @@ -208,7 +208,7 @@ def INPUT_TYPES(cls): "text": ("STRING", {"multiline": True, "dynamicPrompts": True}), }} - RETURN_TYPES = ("CONDITIONING", "HAS_QUOTED_TEXT") + RETURN_TYPES = ("CONDITIONING", "BOOLEAN") RETURN_NAMES = ("conditioning", "has_quoted_text") FUNCTION = "encode" @@ -239,13 +239,7 @@ def encode(self, clip, text): conditioning = clip.encode_from_tokens_scheduled(tokens) - c = [] - for t in conditioning: - n = [t[0], t[1].copy()] - n[1]['has_quoted_text'] = has_quoted_text - c.append(n) - - return (c, has_quoted_text) + return (conditioning, has_quoted_text) class EmptyHunyuanLatentVideo: @classmethod From fd20999994bb287411c2b06e9b7c8fc8b82cdfe1 Mon Sep 17 00:00:00 2001 From: KimbingNg Date: Thu, 25 Sep 2025 17:55:15 +0800 Subject: [PATCH 6/6] remove step tracking code --- comfy_extras/nodes_hunyuan.py | 25 ++++++++++--------------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/comfy_extras/nodes_hunyuan.py b/comfy_extras/nodes_hunyuan.py index fb1686188a9b..4d4d6b8d20b1 100644 --- a/comfy_extras/nodes_hunyuan.py +++ b/comfy_extras/nodes_hunyuan.py @@ -1,3 +1,4 @@ +import math import nodes import node_helpers import torch @@ -91,9 +92,9 @@ def __init__( self.use_original_formulation = use_original_formulation self.momentum_buffer = None - def __call__(self, pred_cond: torch.Tensor, pred_uncond=None, step=None) -> torch.Tensor: + def __call__(self, pred_cond: torch.Tensor, pred_uncond=None, is_first_step=False) -> torch.Tensor: - if step == 0 and self.adaptive_projected_guidance_momentum is not None: + if is_first_step and self.adaptive_projected_guidance_momentum is not None: self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum) pred = normalized_guidance_apg( @@ -160,42 +161,36 @@ def apply_mix_mode_apg(self, model, has_quoted_text, guidance_scale, general_et general_start_t = model_sampling.percent_to_sigma(general_start_percent) ocr_start_t = model_sampling.percent_to_sigma(ocr_start_percent) - step_tracker = {"step": 0} - - def hunyuan_apg_outer_sample_wrapper(executor, *args, **kwargs): - step_tracker['step'] = 0 - return executor(*args, **kwargs) def cfg_function(args): sigma = args["sigma"].to(torch.float32) - sigma = sigma[:, None, None, None] + is_first_step = math.isclose(sigma.item(), args['model_options']['transformer_options']['sample_sigmas'][0].item()) cond = args["cond"] uncond = args["uncond"] cond_scale = args["cond_scale"] - step = step_tracker['step'] - step_tracker['step'] += 1 + sigma = sigma[:, None, None, None] + if not has_quoted_text: if sigma[0] <= general_start_t: - modified_cond = general_apg(cond / sigma, uncond / sigma, step) + modified_cond = general_apg(cond / sigma, uncond / sigma, is_first_step=is_first_step) return modified_cond * sigma else: if cond_scale > 1: - _ = general_apg(cond / sigma, uncond / sigma, step) # track momentum + _ = general_apg(cond / sigma, uncond / sigma, is_first_step=is_first_step) # track momentum return uncond + (cond - uncond) * cond_scale else: if sigma[0] <= ocr_start_t: - modified_cond = ocr_apg(cond / sigma, uncond / sigma, step) + modified_cond = ocr_apg(cond / sigma, uncond / sigma, is_first_step=is_first_step) return modified_cond * sigma else: if cond_scale > 1: - _ = ocr_apg(cond / sigma, uncond / sigma, step) # track momentum + _ = ocr_apg(cond / sigma, uncond / sigma, is_first_step=is_first_step) # track momentum return uncond + (cond - uncond) * cond_scale return cond - m.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "hunyuan_apg", hunyuan_apg_outer_sample_wrapper) m.set_model_sampler_cfg_function(cfg_function, disable_cfg1_optimization=True) return (m,)