
Commit 0836853

HunyuanImage2.1: Implement Hunyuan APG
1 parent 4f1f26a commit 0836853

File tree

1 file changed: +225 −1 lines changed


comfy_extras/nodes_hunyuan.py

Lines changed: 225 additions & 1 deletion
@@ -1,12 +1,14 @@
+from numpy import arccos
 import nodes
 import node_helpers
 import torch
+import re
 import comfy.model_management
 
 
 class CLIPTextEncodeHunyuanDiT:
     @classmethod
-    def INPUT_TYPES(s):
+    def INPUT_TYPES(cls):
         return {"required": {
             "clip": ("CLIP", ),
             "bert": ("STRING", {"multiline": True, "dynamicPrompts": True}),
@@ -23,6 +25,220 @@ def encode(self, clip, bert, mt5xl):
 
         return (clip.encode_from_tokens_scheduled(tokens), )
 
+class MomentumBuffer:
+    def __init__(self, momentum: float):
+        self.momentum = momentum
+        self.running_average = 0
+
+    def update(self, update_value: torch.Tensor):
+        new_average = self.momentum * self.running_average
+        self.running_average = update_value + new_average
+
+def normalized_guidance_apg(
+    pred_cond: torch.Tensor,
+    pred_uncond: torch.Tensor,
+    guidance_scale: float,
+    momentum_buffer,
+    eta: float = 1.0,
+    norm_threshold: float = 0.0,
+    use_original_formulation: bool = False,
+):
+    diff = pred_cond - pred_uncond
+    dim = [-i for i in range(1, len(diff.shape))]
+
+    if momentum_buffer is not None:
+        momentum_buffer.update(diff)
+        diff = momentum_buffer.running_average
+
+    if norm_threshold > 0:
+        ones = torch.ones_like(diff)
+        diff_norm = diff.norm(p=2, dim=dim, keepdim=True)
+        scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
+        diff = diff * scale_factor
+
+    v0, v1 = diff.double(), pred_cond.double()
+    v1 = torch.nn.functional.normalize(v1, dim=dim)
+    v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
+    v0_orthogonal = v0 - v0_parallel
+    diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff)
+
+    normalized_update = diff_orthogonal + eta * diff_parallel
+    pred = pred_cond if use_original_formulation else pred_uncond
+    pred = pred + guidance_scale * normalized_update
+
+    return pred
+
+class AdaptiveProjectedGuidance:
+    def __init__(
+        self,
+        guidance_scale: float = 7.5,
+        adaptive_projected_guidance_momentum=None,
+        adaptive_projected_guidance_rescale: float = 15.0,
+        # eta: float = 1.0,
+        eta: float = 0.0,
+        guidance_rescale: float = 0.0,
+        use_original_formulation: bool = False,
+        start: float = 0.0,
+        stop: float = 1.0,
+    ):
+        super().__init__()
+
+        self.guidance_scale = guidance_scale
+        self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
+        self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale
+        self.eta = eta
+        self.guidance_rescale = guidance_rescale
+        self.use_original_formulation = use_original_formulation
+        self.momentum_buffer = None
+
+    def __call__(self, pred_cond: torch.Tensor, pred_uncond=None, step=None) -> torch.Tensor:
+
+        if step == 0 and self.adaptive_projected_guidance_momentum is not None:
+            self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
+
+        pred = normalized_guidance_apg(
+            pred_cond,
+            pred_uncond,
+            self.guidance_scale,
+            self.momentum_buffer,
+            self.eta,
+            self.adaptive_projected_guidance_rescale,
+            self.use_original_formulation,
+        )
+
+        return pred
+
+class HunyuanMixModeAPG:
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "model": ("MODEL", ),
+                "has_quoted_text": ("HAS_QUOTED_TEXT", ),
+
+                "guidance_scale": ("FLOAT", {"default": 8.0, "min": 1.0, "max": 30.0, "step": 0.1}),
+
+                "general_eta": ("FLOAT", {"default": 0.0, "min": -10.0, "max": 10.0, "step": 0.01}),
+                "general_norm_threshold": ("FLOAT", {"default": 10.0, "min": 0.0, "max": 50.0, "step": 0.1}),
+                "general_momentum": ("FLOAT", {"default": -0.5, "min": -5.0, "max": 1.0, "step": 0.01}),
+                "general_start_step": ("INT", {"default": 10, "min": -1, "max": 1000}),
+
+                "ocr_eta": ("FLOAT", {"default": 0.0, "min": -10.0, "max": 10.0, "step": 0.01}),
+                "ocr_norm_threshold": ("FLOAT", {"default": 10.0, "min": 0.0, "max": 50.0, "step": 0.1}),
+                "ocr_momentum": ("FLOAT", {"default": -0.5, "min": -5.0, "max": 1.0, "step": 0.01}),
+                "ocr_start_step": ("INT", {"default": 75, "min": -1, "max": 1000}),
+
+            }
+        }
+
+    RETURN_TYPES = ("MODEL",)
+    FUNCTION = "apply_mix_mode_apg"
+    CATEGORY = "sampling/custom_sampling/hunyuan"
+
+
+    @classmethod
+    def IS_CHANGED(cls, *args, **kwargs):
+        return True
+
+    def apply_mix_mode_apg(self, model, has_quoted_text, guidance_scale, general_eta, general_norm_threshold, general_momentum, general_start_step,
+                           ocr_eta, ocr_norm_threshold, ocr_momentum, ocr_start_step):
+
+        general_apg = AdaptiveProjectedGuidance(
+            guidance_scale=guidance_scale,
+            eta=general_eta,
+            adaptive_projected_guidance_rescale=general_norm_threshold,
+            adaptive_projected_guidance_momentum=general_momentum
+        )
+
+        ocr_apg = AdaptiveProjectedGuidance(
+            eta=ocr_eta,
+            adaptive_projected_guidance_rescale=ocr_norm_threshold,
+            adaptive_projected_guidance_momentum=ocr_momentum
+        )
+
+        current_step = {"step": 0}
+
+        def cfg_function(args):
+            cond = args["cond"]
+            uncond = args["uncond"]
+            cond_scale = args["cond_scale"]
+
+            step = current_step["step"]
+            current_step["step"] += 1
+
+            if not has_quoted_text:
+                if step > general_start_step:
+                    modified_cond = general_apg(cond, uncond, step).to(torch.bfloat16)
+                    return modified_cond
+                else:
+                    if cond_scale > 1:
+                        _ = general_apg(cond, uncond, step)  # track momentum
+                    return uncond + (cond - uncond) * cond_scale
+            else:
+                if step > ocr_start_step:
+                    modified_cond = ocr_apg(cond, uncond, step)
+                    return modified_cond
+                else:
+                    if cond_scale > 1:
+                        _ = ocr_apg(cond, uncond, step)
+                    return uncond + (cond - uncond) * cond_scale
+
+            return cond
+
+
+        m = model.clone()
+        m.set_model_sampler_cfg_function(cfg_function, disable_cfg1_optimization=True)
+        return (m,)
+
+class CLIPTextEncodeHunyuanDiTWithTextDetection:
+
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {"required": {
+            "clip": ("CLIP", ),
+            "text": ("STRING", {"multiline": True, "dynamicPrompts": True}),
+        }}
+
+    RETURN_TYPES = ("CONDITIONING", "HAS_QUOTED_TEXT")
+    RETURN_NAMES = ("conditioning", "has_quoted_text")
+    FUNCTION = "encode"
+
+    CATEGORY = "advanced/conditioning/hunyuan"
+
+    def detect_quoted_text(self, text):
+        """Detect quoted text in the prompt"""
+        text_prompt_texts = []
+
+        # Patterns to match different quote styles
+        pattern_quote_double = r'\"(.*?)\"'
+        pattern_quote_chinese_single = r'‘(.*?)’'
+        pattern_quote_chinese_double = r'“(.*?)”'
+
+        matches_quote_double = re.findall(pattern_quote_double, text)
+        matches_quote_chinese_single = re.findall(pattern_quote_chinese_single, text)
+        matches_quote_chinese_double = re.findall(pattern_quote_chinese_double, text)
+
+        text_prompt_texts.extend(matches_quote_double)
+        text_prompt_texts.extend(matches_quote_chinese_single)
+        text_prompt_texts.extend(matches_quote_chinese_double)
+
+        return len(text_prompt_texts) > 0
+
+    def encode(self, clip, text):
+        tokens = clip.tokenize(text)
+        has_quoted_text = self.detect_quoted_text(text)
+
+        conditioning = clip.encode_from_tokens_scheduled(tokens)
+
+        c = []
+        for t in conditioning:
+            n = [t[0], t[1].copy()]
+            n[1]['has_quoted_text'] = has_quoted_text
+            c.append(n)
+
+        return (c, has_quoted_text)
+
 class EmptyHunyuanLatentVideo:
     @classmethod
     def INPUT_TYPES(s):
@@ -151,8 +367,16 @@ def execute(self, positive, negative, latent, noise_augmentation):
         return (positive, negative, out_latent)
 
 
+
+NODE_DISPLAY_NAME_MAPPINGS = {
+    "HunyuanMixModeAPG": "Hunyuan Mix Mode APG",
+    "HunyuanStepBasedAPG": "Hunyuan Step Based APG",
+}
+
 NODE_CLASS_MAPPINGS = {
+    "HunyuanMixModeAPG": HunyuanMixModeAPG,
     "CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT,
+    "CLIPTextEncodeHunyuanDiTWithTextDetection": CLIPTextEncodeHunyuanDiTWithTextDetection,
     "TextEncodeHunyuanVideo_ImageToVideo": TextEncodeHunyuanVideo_ImageToVideo,
     "EmptyHunyuanLatentVideo": EmptyHunyuanLatentVideo,
     "HunyuanImageToVideo": HunyuanImageToVideo,

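The heart of the added code is the projection in normalized_guidance_apg: the usual CFG delta (pred_cond - pred_uncond) is split into a component parallel to the normalized conditional prediction and an orthogonal remainder, and the guidance scale is applied to the eta-weighted recombination; with the node defaults of eta = 0.0, only the orthogonal part survives. MomentumBuffer just carries a momentum-weighted running term for that delta across steps. Below is a minimal standalone sketch of the projection on dummy tensors; it is not part of the commit, and the shapes, seed, and variable names are illustrative only.

# Standalone sketch (not part of the commit): the projection step performed by
# normalized_guidance_apg, shown on dummy tensors with momentum disabled.
import torch

torch.manual_seed(0)
pred_cond = torch.randn(1, 4, 8, 8)    # stand-in for the conditional model output
pred_uncond = torch.randn(1, 4, 8, 8)  # stand-in for the unconditional output
guidance_scale, eta = 8.0, 0.0         # eta = 0.0 matches the node defaults

diff = pred_cond - pred_uncond                    # ordinary CFG delta
dim = [-i for i in range(1, len(diff.shape))]     # every non-batch dimension

v1 = torch.nn.functional.normalize(pred_cond.double(), dim=dim)
parallel = (diff.double() * v1).sum(dim=dim, keepdim=True) * v1  # component along pred_cond
orthogonal = diff.double() - parallel                            # perpendicular remainder

update = (orthogonal + eta * parallel).type_as(diff)
pred = pred_uncond + guidance_scale * update      # guided output; eta = 0 keeps only the orthogonal part

print((parallel * orthogonal).sum().item())       # ~0: the two components are orthogonal

Dropping (or down-weighting) the parallel component is how APG is intended to curb the over-saturation that plain CFG tends to produce at high guidance scales.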
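Which APG configuration a step uses is decided by detect_quoted_text in CLIPTextEncodeHunyuanDiTWithTextDetection: a prompt containing ASCII double quotes or Chinese-style quotes is treated as a text-rendering (OCR) prompt and routed to the ocr_* settings, everything else falls through to the general_* settings. The snippet below is a standalone re-implementation of that check for illustration; the has_quoted_text helper name is made up here, and this is not the node itself (which needs a running ComfyUI).

# Standalone sketch (not part of the commit): the quote-detection rule used to
# choose between the ocr_* and general_* APG branches.
import re

def has_quoted_text(text: str) -> bool:
    # Same three patterns as the node: ASCII double quotes plus Chinese single/double quotes.
    patterns = [r'\"(.*?)\"', r'‘(.*?)’', r'“(.*?)”']
    return any(re.findall(p, text) for p in patterns)

print(has_quoted_text('a neon sign that says "OPEN"'))  # True  -> ocr_* branch
print(has_quoted_text('a mountain lake at sunrise'))    # False -> general_* branch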