
Commit 192b74c

HunyuanImage2.1: Fix refiner template
1 parent 0836853 commit 192b74c

File tree

4 files changed (+113 -16 lines changed)

comfy/sd.py

Lines changed: 4 additions & 1 deletion
@@ -836,7 +836,7 @@ class CLIPType(Enum):
     OMNIGEN2 = 17
     QWEN_IMAGE = 18
     HUNYUAN_IMAGE = 19
-
+    HUNYUAN_IMAGE_REFINER = 20

 def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
     clip_data = []
@@ -995,6 +995,9 @@ class EmptyClass:
             if clip_type == CLIPType.HUNYUAN_IMAGE:
                 clip_target.clip = comfy.text_encoders.hunyuan_image.te(byt5=False, **llama_detect(clip_data))
                 clip_target.tokenizer = comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer
+            elif clip_type == CLIPType.HUNYUAN_IMAGE_REFINER:
+                clip_target.clip = comfy.text_encoders.hunyuan_image.te(byt5=False, refiner=True, **llama_detect(clip_data))
+                clip_target.tokenizer = comfy.text_encoders.hunyuan_image.HunyuanImageRefinerTokenizer
             else:
                 clip_target.clip = comfy.text_encoders.qwen_image.te(**llama_detect(clip_data))
                 clip_target.tokenizer = comfy.text_encoders.qwen_image.QwenImageTokenizer
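Note: a minimal sketch of how the new enum value would be exercised through load_clip; the checkpoint file name below is hypothetical, and only the keyword names visible in the diff above are assumed.

import comfy.sd

# Hypothetical text-encoder checkpoint path; any Qwen2.5 7B text-encoder file would stand in here.
clip = comfy.sd.load_clip(
    ckpt_paths=["models/text_encoders/qwen_2.5_vl_7b.safetensors"],
    clip_type=comfy.sd.CLIPType.HUNYUAN_IMAGE_REFINER,
)
# Per the branch added above, this selects te(byt5=False, refiner=True, ...) and the
# HunyuanImageRefinerTokenizer instead of the base HunyuanImageTokenizer.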

comfy/text_encoders/hunyuan_image.py

Lines changed: 76 additions & 9 deletions
@@ -4,6 +4,8 @@
 from transformers import ByT5Tokenizer
 import os
 import re
+import torch
+import numbers

 class ByT5SmallTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
@@ -38,6 +40,13 @@ def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
         out['byt5'] = self.byt5.tokenize_with_weights(''.join(map(lambda a: 'Text "{}". '.format(a), text_prompt_texts)), return_word_ids, **kwargs)
         return out

+class HunyuanImageRefinerTokenizer(HunyuanImageTokenizer):
+    def __init__(self, embedding_directory=None, tokenizer_data={}):
+        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
+        self.llama_template = "<|start_header_id|>system<|end_header_id|>\n\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
+
+
+
 class Qwen25_7BVLIModel(sd1_clip.SDClipModel):
     def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}):
         llama_scaled_fp8 = model_options.get("qwen_scaled_fp8", None)
@@ -53,21 +62,45 @@ def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model
         super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, model_options=model_options, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True)


-class HunyuanImageTEModel(QwenImageTEModel):
+class HunyuanImageTEModel(sd1_clip.SD1ClipModel):
     def __init__(self, byt5=True, device="cpu", dtype=None, model_options={}):
-        super(QwenImageTEModel, self).__init__(device=device, dtype=dtype, name="qwen25_7b", clip_model=Qwen25_7BVLIModel, model_options=model_options)
+        super().__init__(device=device, dtype=dtype, name="qwen25_7b", clip_model=Qwen25_7BVLIModel, model_options=model_options)

         if byt5:
             self.byt5_small = ByT5SmallModel(device=device, dtype=dtype, model_options=model_options)
         else:
             self.byt5_small = None

     def encode_token_weights(self, token_weight_pairs):
-        cond, p, extra = super().encode_token_weights(token_weight_pairs)
+        out, pooled, extra = super().encode_token_weights(token_weight_pairs)
+        tok_pairs = token_weight_pairs["qwen25_7b"][0]
+        count_im_start = 0
+        for i, v in enumerate(tok_pairs):
+            elem = v[0]
+            if not torch.is_tensor(elem):
+                if isinstance(elem, numbers.Integral):
+                    if elem == 151644 and count_im_start < 2:
+                        template_end = i
+                        count_im_start += 1
+
+        if out.shape[1] > (template_end + 3):
+            if tok_pairs[template_end + 1][0] == 872:
+                if tok_pairs[template_end + 2][0] == 198:
+                    template_end += 3
+
+        out = out[:, template_end:]
+
+        extra["attention_mask"] = extra["attention_mask"][:, template_end:]
+        if extra["attention_mask"].sum() == torch.numel(extra["attention_mask"]):
+            extra.pop("attention_mask") # attention mask is useless if no masked elements
+
+
         if self.byt5_small is not None and "byt5" in token_weight_pairs:
-            out = self.byt5_small.encode_token_weights(token_weight_pairs["byt5"])
-            extra["conditioning_byt5small"] = out[0]
-        return cond, p, extra
+            byt5_out = self.byt5_small.encode_token_weights(token_weight_pairs["byt5"])
+            extra["conditioning_byt5small"] = byt5_out[0]
+        return out, pooled, extra
+
+

     def set_clip_options(self, options):
         super().set_clip_options(options)
@@ -84,14 +117,48 @@ def load_sd(self, sd):
             return self.byt5_small.load_sd(sd)
         else:
             return super().load_sd(sd)
+class HunyuanImageRefinerTEModel(sd1_clip.SD1ClipModel):
+    def __init__(self, device="cpu", dtype=None, model_options={}):
+        super().__init__(device=device, dtype=dtype, name="qwen25_7b", clip_model=Qwen25_7BVLIModel, model_options=model_options)
+
+    def encode_token_weights(self, token_weight_pairs):
+        out, pooled, extra = super().encode_token_weights(token_weight_pairs)
+        tok_pairs = token_weight_pairs["qwen25_7b"][0]
+        for i, v in enumerate(tok_pairs):
+            elem = v[0]
+            if not torch.is_tensor(elem):
+                if isinstance(elem, numbers.Integral):
+                    if elem == 6171:
+                        template_end = i
+                        break
+
+        out = out[:, template_end-1:]
+
+        extra["attention_mask"] = extra["attention_mask"][:, template_end-1:]
+        if extra["attention_mask"].sum() == torch.numel(extra["attention_mask"]):
+            extra.pop("attention_mask") # attention mask is useless if no masked elements
+
+        return out, pooled, extra
+
+
+def te(byt5=True, dtype_llama=None, llama_scaled_fp8=None, refiner=False):
+    class HunyuanImageTEModel_(HunyuanImageTEModel):

-def te(byt5=True, dtype_llama=None, llama_scaled_fp8=None):
-    class QwenImageTEModel_(HunyuanImageTEModel):
         def __init__(self, device="cpu", dtype=None, model_options={}):
             if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
                 model_options = model_options.copy()
                 model_options["qwen_scaled_fp8"] = llama_scaled_fp8
             if dtype_llama is not None:
                 dtype = dtype_llama
             super().__init__(byt5=byt5, device=device, dtype=dtype, model_options=model_options)
-    return QwenImageTEModel_
+    class HunyuanImageTEModel_refiner(HunyuanImageRefinerTEModel):
+        def __init__(self, device="cpu", dtype=None, model_options={}):
+            if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
+                model_options = model_options.copy()
+                model_options["qwen_scaled_fp8"] = llama_scaled_fp8
+            if dtype_llama is not None:
+                dtype = dtype_llama
+            assert refiner, "refiner must be True"
+            assert not byt5, "byt5 must be False"
+            super().__init__(device=device, dtype=dtype, model_options=model_options)
+    return HunyuanImageTEModel_refiner if refiner else HunyuanImageTEModel_
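Note: the added encode_token_weights logic trims the chat-template tokens from the front of the hidden states before they are used as conditioning. Below is a self-contained toy illustration of the same slicing; the IDs 151644, 872 and 198 are the values checked in the commit (they appear to be the Qwen tokenizer's <|im_start|>, "user" and newline tokens), while the token list, hidden size and other values are fabricated example data.

import torch

IM_START = 151644
tok_pairs = [(101, 1.0), (IM_START, 1.0), (872, 1.0), (198, 1.0),
             (IM_START, 1.0), (872, 1.0), (198, 1.0), (555, 1.0), (556, 1.0)]
out = torch.randn(1, len(tok_pairs), 8)         # stand-in for encoder hidden states
attention_mask = torch.ones(1, len(tok_pairs))  # all ones: nothing is padded

# Find the second <|im_start|> marker, as in HunyuanImageTEModel.encode_token_weights.
template_end = 0
count_im_start = 0
for i, (elem, _) in enumerate(tok_pairs):
    if isinstance(elem, int) and elem == IM_START and count_im_start < 2:
        template_end = i
        count_im_start += 1

# Also skip the "user" + newline tokens that follow the marker.
if out.shape[1] > (template_end + 3):
    if tok_pairs[template_end + 1][0] == 872 and tok_pairs[template_end + 2][0] == 198:
        template_end += 3

out = out[:, template_end:]                      # drop the template portion
attention_mask = attention_mask[:, template_end:]
if attention_mask.sum() == torch.numel(attention_mask):
    attention_mask = None                        # mask carries no information if all ones

print(out.shape, attention_mask)                 # torch.Size([1, 2, 8]) None

The refiner variant follows the same pattern but keys on token 6171 and keeps one extra leading token (template_end-1), matching its different llama_template.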

comfy_extras/nodes_hunyuan.py

Lines changed: 32 additions & 5 deletions
@@ -199,9 +199,8 @@ def INPUT_TYPES(cls):
             "clip": ("CLIP", ),
             "text": ("STRING", {"multiline": True, "dynamicPrompts": True}),
             }}
-
-    RETURN_TYPES = ("CONDITIONING", "HAS_QUOTED_TEXT")
-    RETURN_NAMES = ("conditioning", "has_quoted_text")
+    RETURN_TYPES = ("CONDITIONING", "HAS_QUOTED_TEXT", "STRING")
+    RETURN_NAMES = ("conditioning", "has_quoted_text", "text")
     FUNCTION = "encode"

     CATEGORY = "advanced/conditioning/hunyuan"
@@ -237,7 +236,35 @@ def encode(self, clip, text):
             n[1]['has_quoted_text'] = has_quoted_text
             c.append(n)

-        return (c, has_quoted_text)
+        return (c, has_quoted_text, text)
+
+class CLIPTextEncodeHunyuanImageRefiner:
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {"required": {
+            "clip": ("CLIP", ),
+            "text": ("STRING", ),
+            }}
+    RETURN_TYPES = ("CONDITIONING",)
+    RETURN_NAMES = ("conditioning",)
+    FUNCTION = "encode"
+
+    CATEGORY = "advanced/conditioning/hunyuan"
+
+
+    def encode(self, clip, text):
+        tokens = clip.tokenize(text)
+
+        conditioning = clip.encode_from_tokens_scheduled(tokens)
+
+        c = []
+        for t in conditioning:
+            n = [t[0], t[1].copy()]
+            c.append(n)
+
+        return (c, )
+
+

 class EmptyHunyuanLatentVideo:
     @classmethod
@@ -370,13 +397,13 @@ def execute(self, positive, negative, latent, noise_augmentation):

 NODE_DISPLAY_NAME_MAPPINGS = {
     "HunyuanMixModeAPG": "Hunyuan Mix Mode APG",
-    "HunyuanStepBasedAPG": "Hunyuan Step Based APG",
 }

 NODE_CLASS_MAPPINGS = {
     "HunyuanMixModeAPG": HunyuanMixModeAPG,
     "CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT,
     "CLIPTextEncodeHunyuanDiTWithTextDetection": CLIPTextEncodeHunyuanDiTWithTextDetection,
+    "CLIPTextEncodeHunyuanImageRefiner": CLIPTextEncodeHunyuanImageRefiner,
     "TextEncodeHunyuanVideo_ImageToVideo": TextEncodeHunyuanVideo_ImageToVideo,
     "EmptyHunyuanLatentVideo": EmptyHunyuanLatentVideo,
     "HunyuanImageToVideo": HunyuanImageToVideo,

nodes.py

Lines changed: 1 addition & 1 deletion
@@ -929,7 +929,7 @@ class CLIPLoader:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
-                              "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image"], ),
+                              "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image","hunyuan_image_refiner"], ),
                               },
                 "optional": {
                               "device": (["default", "cpu"], {"advanced": True}),
