 import nodes
 import node_helpers
 import torch
+import re
 import comfy.model_management


 class CLIPTextEncodeHunyuanDiT:
     @classmethod
-    def INPUT_TYPES(s):
+    def INPUT_TYPES(cls):
         return {"required": {
             "clip": ("CLIP", ),
             "bert": ("STRING", {"multiline": True, "dynamicPrompts": True}),
@@ -23,6 +25,220 @@ def encode(self, clip, bert, mt5xl):

         return (clip.encode_from_tokens_scheduled(tokens), )

+class MomentumBuffer:
+    def __init__(self, momentum: float):
+        self.momentum = momentum
+        self.running_average = 0
+
+    def update(self, update_value: torch.Tensor):
+        # Leaky running average: decay the previous average by `momentum`,
+        # then add the new value.
+        new_average = self.momentum * self.running_average
+        self.running_average = update_value + new_average
+
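+# Minimal usage sketch (illustrative only, not part of the node code path):
+#   buf = MomentumBuffer(momentum=-0.5)
+#   buf.update(torch.tensor(1.0))   # running_average -> 1.0
+#   buf.update(torch.tensor(1.0))   # running_average -> 1.0 + (-0.5 * 1.0) = 0.5
+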
+def normalized_guidance_apg(
+    pred_cond: torch.Tensor,
+    pred_uncond: torch.Tensor,
+    guidance_scale: float,
+    momentum_buffer,
+    eta: float = 1.0,
+    norm_threshold: float = 0.0,
+    use_original_formulation: bool = False,
+):
+    diff = pred_cond - pred_uncond
+    dim = [-i for i in range(1, len(diff.shape))]
+
+    if momentum_buffer is not None:
+        momentum_buffer.update(diff)
+        diff = momentum_buffer.running_average
+
+    if norm_threshold > 0:
+        ones = torch.ones_like(diff)
+        diff_norm = diff.norm(p=2, dim=dim, keepdim=True)
+        scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
+        diff = diff * scale_factor
+
+    v0, v1 = diff.double(), pred_cond.double()
+    v1 = torch.nn.functional.normalize(v1, dim=dim)
+    v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
+    v0_orthogonal = v0 - v0_parallel
+    diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff)
+
+    normalized_update = diff_orthogonal + eta * diff_parallel
+    pred = pred_cond if use_original_formulation else pred_uncond
+    pred = pred + guidance_scale * normalized_update
+
+    return pred
+
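+# The function above implements Adaptive Projected Guidance: the
+# (momentum-smoothed, norm-clipped) guidance direction is decomposed into
+# components parallel and orthogonal to pred_cond, and only the parallel part
+# is rescaled by eta. eta == 1 keeps the full direction (plain CFG up to the
+# smoothing/clipping); eta == 0 keeps only the orthogonal component, which the
+# APG formulation uses to curb oversaturation at high guidance scales.
+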
+class AdaptiveProjectedGuidance:
+    def __init__(
+        self,
+        guidance_scale: float = 7.5,
+        adaptive_projected_guidance_momentum=None,
+        adaptive_projected_guidance_rescale: float = 15.0,
+        # eta: float = 1.0,
+        eta: float = 0.0,
+        guidance_rescale: float = 0.0,
+        use_original_formulation: bool = False,
+        start: float = 0.0,
+        stop: float = 1.0,
+    ):
+        super().__init__()
+
+        self.guidance_scale = guidance_scale
+        self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
+        self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale
+        self.eta = eta
+        self.guidance_rescale = guidance_rescale
+        self.use_original_formulation = use_original_formulation
+        self.momentum_buffer = None
+
+    def __call__(self, pred_cond: torch.Tensor, pred_uncond=None, step=None) -> torch.Tensor:
+        # Reset the momentum buffer at the start of each sampling run.
+        if step == 0 and self.adaptive_projected_guidance_momentum is not None:
+            self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
+
+        pred = normalized_guidance_apg(
+            pred_cond,
+            pred_uncond,
+            self.guidance_scale,
+            self.momentum_buffer,
+            self.eta,
+            self.adaptive_projected_guidance_rescale,
+            self.use_original_formulation,
+        )
+
+        return pred
+
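+# Minimal usage sketch (illustrative values; cond/uncond are model outputs):
+#   apg = AdaptiveProjectedGuidance(guidance_scale=8.0,
+#                                   adaptive_projected_guidance_momentum=-0.5)
+#   guided = apg(cond, uncond, step=0)   # step == 0 (re)creates the momentum buffer
+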
+class HunyuanMixModeAPG:
+
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "required": {
+                "model": ("MODEL", ),
+                "has_quoted_text": ("HAS_QUOTED_TEXT", ),
+
+                "guidance_scale": ("FLOAT", {"default": 8.0, "min": 1.0, "max": 30.0, "step": 0.1}),
+
+                "general_eta": ("FLOAT", {"default": 0.0, "min": -10.0, "max": 10.0, "step": 0.01}),
+                "general_norm_threshold": ("FLOAT", {"default": 10.0, "min": 0.0, "max": 50.0, "step": 0.1}),
+                "general_momentum": ("FLOAT", {"default": -0.5, "min": -5.0, "max": 1.0, "step": 0.01}),
+                "general_start_step": ("INT", {"default": 10, "min": -1, "max": 1000}),
+
+                "ocr_eta": ("FLOAT", {"default": 0.0, "min": -10.0, "max": 10.0, "step": 0.01}),
+                "ocr_norm_threshold": ("FLOAT", {"default": 10.0, "min": 0.0, "max": 50.0, "step": 0.1}),
+                "ocr_momentum": ("FLOAT", {"default": -0.5, "min": -5.0, "max": 1.0, "step": 0.01}),
+                "ocr_start_step": ("INT", {"default": 75, "min": -1, "max": 1000}),
+            }
+        }
+
+    RETURN_TYPES = ("MODEL",)
+    FUNCTION = "apply_mix_mode_apg"
+    CATEGORY = "sampling/custom_sampling/hunyuan"
+
+    @classmethod
+    def IS_CHANGED(cls, *args, **kwargs):
+        return True
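+    # Note: ComfyUI decides whether to re-run a node by comparing successive
+    # IS_CHANGED results, so a constant True may still be cached; returning
+    # float("NaN") is the usual idiom to force re-execution on every run.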
+
+    def apply_mix_mode_apg(self, model, has_quoted_text, guidance_scale,
+                           general_eta, general_norm_threshold, general_momentum, general_start_step,
+                           ocr_eta, ocr_norm_threshold, ocr_momentum, ocr_start_step):
+
+        general_apg = AdaptiveProjectedGuidance(
+            guidance_scale=guidance_scale,
+            eta=general_eta,
+            adaptive_projected_guidance_rescale=general_norm_threshold,
+            adaptive_projected_guidance_momentum=general_momentum
+        )
+
+        # Note: ocr_apg keeps the class default guidance_scale (7.5); only the
+        # general branch receives the node's guidance_scale input.
+        ocr_apg = AdaptiveProjectedGuidance(
+            eta=ocr_eta,
+            adaptive_projected_guidance_rescale=ocr_norm_threshold,
+            adaptive_projected_guidance_momentum=ocr_momentum
+        )
+
+        current_step = {"step": 0}
+
+        def cfg_function(args):
+            cond = args["cond"]
+            uncond = args["uncond"]
+            cond_scale = args["cond_scale"]
+
+            step = current_step["step"]
+            current_step["step"] += 1
+
+            if not has_quoted_text:
+                # General prompts: plain CFG until general_start_step, APG afterwards.
+                if step > general_start_step:
+                    modified_cond = general_apg(cond, uncond, step).to(torch.bfloat16)
+                    return modified_cond
+                else:
+                    if cond_scale > 1:
+                        _ = general_apg(cond, uncond, step)  # track momentum
+                    return uncond + (cond - uncond) * cond_scale
+            else:
+                # Prompts with quoted text: use the OCR-oriented APG settings.
+                if step > ocr_start_step:
+                    modified_cond = ocr_apg(cond, uncond, step)
+                    return modified_cond
+                else:
+                    if cond_scale > 1:
+                        _ = ocr_apg(cond, uncond, step)  # track momentum
+                    return uncond + (cond - uncond) * cond_scale
+
+            return cond
+
+        m = model.clone()
+        m.set_model_sampler_cfg_function(cfg_function, disable_cfg1_optimization=True)
+        return (m,)
+
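+# Sketch of the intended wiring (illustrative; node names as registered below):
+#   conditioning, has_quoted_text = CLIPTextEncodeHunyuanDiTWithTextDetection(clip, text)
+#   patched_model = HunyuanMixModeAPG(model, has_quoted_text, ...)
+# The patched model applies plain CFG early and switches to APG after the
+# configured start step, choosing general vs. OCR settings by has_quoted_text.
+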
+class CLIPTextEncodeHunyuanDiTWithTextDetection:
+
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {"required": {
+            "clip": ("CLIP", ),
+            "text": ("STRING", {"multiline": True, "dynamicPrompts": True}),
+        }}
+
+    RETURN_TYPES = ("CONDITIONING", "HAS_QUOTED_TEXT")
+    RETURN_NAMES = ("conditioning", "has_quoted_text")
+    FUNCTION = "encode"
+
+    CATEGORY = "advanced/conditioning/hunyuan"
+
+    def detect_quoted_text(self, text):
+        """Detect quoted text in the prompt."""
+        text_prompt_texts = []
+
+        # Patterns to match different quote styles
+        pattern_quote_double = r'\"(.*?)\"'
+        pattern_quote_chinese_single = r'‘(.*?)’'
+        pattern_quote_chinese_double = r'“(.*?)”'
+
+        matches_quote_double = re.findall(pattern_quote_double, text)
+        matches_quote_chinese_single = re.findall(pattern_quote_chinese_single, text)
+        matches_quote_chinese_double = re.findall(pattern_quote_chinese_double, text)
+
+        text_prompt_texts.extend(matches_quote_double)
+        text_prompt_texts.extend(matches_quote_chinese_single)
+        text_prompt_texts.extend(matches_quote_chinese_double)
+
+        return len(text_prompt_texts) > 0
+
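+# Examples (illustrative): 'a sign that says "OPEN"' and '招牌上写着“营业中”'
+# both count as containing quoted text; 'a plain landscape' does not.
+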
+    def encode(self, clip, text):
+        tokens = clip.tokenize(text)
+        has_quoted_text = self.detect_quoted_text(text)
+
+        conditioning = clip.encode_from_tokens_scheduled(tokens)
+
+        # Copy each conditioning entry and tag it with the detection result.
+        c = []
+        for t in conditioning:
+            n = [t[0], t[1].copy()]
+            n[1]['has_quoted_text'] = has_quoted_text
+            c.append(n)
+
+        return (c, has_quoted_text)
+
 class EmptyHunyuanLatentVideo:
     @classmethod
     def INPUT_TYPES(s):
@@ -151,8 +367,16 @@ def execute(self, positive, negative, latent, noise_augmentation):
         return (positive, negative, out_latent)


+
+NODE_DISPLAY_NAME_MAPPINGS = {
+    "HunyuanMixModeAPG": "Hunyuan Mix Mode APG",
+}
+
 NODE_CLASS_MAPPINGS = {
+    "HunyuanMixModeAPG": HunyuanMixModeAPG,
     "CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT,
+    "CLIPTextEncodeHunyuanDiTWithTextDetection": CLIPTextEncodeHunyuanDiTWithTextDetection,
     "TextEncodeHunyuanVideo_ImageToVideo": TextEncodeHunyuanVideo_ImageToVideo,
     "EmptyHunyuanLatentVideo": EmptyHunyuanLatentVideo,
     "HunyuanImageToVideo": HunyuanImageToVideo,