@@ -4,6 +4,8 @@
 from transformers import ByT5Tokenizer
 import os
 import re
+import torch
+import numbers
 
 class ByT5SmallTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
@@ -38,6 +40,13 @@ def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
         out['byt5'] = self.byt5.tokenize_with_weights(''.join(map(lambda a: 'Text "{}". '.format(a), text_prompt_texts)), return_word_ids, **kwargs)
         return out
 
+class HunyuanImageRefinerTokenizer(HunyuanImageTokenizer):
+    def __init__(self, embedding_directory=None, tokenizer_data={}):
+        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
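+        # override the default template with a llama-chat style system prompt that
+        # asks for a detailed description of the image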
+        self.llama_template = "<|start_header_id|>system<|end_header_id|>\n\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
+
+
+
 class Qwen25_7BVLIModel(sd1_clip.SDClipModel):
     def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}):
         llama_scaled_fp8 = model_options.get("qwen_scaled_fp8", None)
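
Note: the 'Text "{}". ' wrapping in tokenize_with_weights above rewrites each text
segment pulled out of the prompt into a short clause before ByT5 tokenization. A
minimal illustration with hypothetical segments:

    segments = ["hello", "world"]
    print(''.join('Text "{}". '.format(s) for s in segments))
    # -> Text "hello". Text "world". 
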
@@ -53,21 +62,45 @@ def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model
         super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, model_options=model_options, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True)
 
 
-class HunyuanImageTEModel(QwenImageTEModel):
+class HunyuanImageTEModel(sd1_clip.SD1ClipModel):
     def __init__(self, byt5=True, device="cpu", dtype=None, model_options={}):
-        super(QwenImageTEModel, self).__init__(device=device, dtype=dtype, name="qwen25_7b", clip_model=Qwen25_7BVLIModel, model_options=model_options)
+        super().__init__(device=device, dtype=dtype, name="qwen25_7b", clip_model=Qwen25_7BVLIModel, model_options=model_options)
 
         if byt5:
             self.byt5_small = ByT5SmallModel(device=device, dtype=dtype, model_options=model_options)
         else:
             self.byt5_small = None
 
     def encode_token_weights(self, token_weight_pairs):
-        cond, p, extra = super().encode_token_weights(token_weight_pairs)
+        out, pooled, extra = super().encode_token_weights(token_weight_pairs)
+        tok_pairs = token_weight_pairs["qwen25_7b"][0]
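+        # scan for <|im_start|> (id 151644 in the Qwen2 vocabulary); the second
+        # occurrence marks the end of the chat-template prefix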
+        count_im_start = 0
+        for i, v in enumerate(tok_pairs):
+            elem = v[0]
+            if not torch.is_tensor(elem):
+                if isinstance(elem, numbers.Integral):
+                    if elem == 151644 and count_im_start < 2:
+                        template_end = i
+                        count_im_start += 1
+
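+        # if the template is immediately followed by the "user" (872) and newline (198)
+        # tokens, trim those away as part of the prefix too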
+        if out.shape[1] > (template_end + 3):
+            if tok_pairs[template_end + 1][0] == 872:
+                if tok_pairs[template_end + 2][0] == 198:
+                    template_end += 3
+
+        out = out[:, template_end:]
+
+        extra["attention_mask"] = extra["attention_mask"][:, template_end:]
+        if extra["attention_mask"].sum() == torch.numel(extra["attention_mask"]):
+            extra.pop("attention_mask")  # attention mask is useless if no masked elements
+
         if self.byt5_small is not None and "byt5" in token_weight_pairs:
-            out = self.byt5_small.encode_token_weights(token_weight_pairs["byt5"])
-            extra["conditioning_byt5small"] = out[0]
-        return cond, p, extra
+            byt5_out = self.byt5_small.encode_token_weights(token_weight_pairs["byt5"])
+            extra["conditioning_byt5small"] = byt5_out[0]
+        return out, pooled, extra
+
+
 
     def set_clip_options(self, options):
         super().set_clip_options(options)
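
For reference, here is a minimal, self-contained sketch of the template-trimming logic
added above, using dummy hidden states and only the token ids the diff itself checks
(151644, 872, 198); everything else is made up:

    import torch

    # hypothetical token stream: [im_start, system-ids..., im_start, user, \n, prompt-ids...]
    tok_pairs = [(151644, 1.0), (9, 1.0), (151644, 1.0), (872, 1.0), (198, 1.0), (5, 1.0), (7, 1.0)]
    out = torch.randn(1, len(tok_pairs), 16)   # fake per-token embeddings
    mask = torch.ones(1, len(tok_pairs))

    # locate the second <|im_start|> (id 151644)
    count_im_start = 0
    template_end = 0
    for i, (tok, _w) in enumerate(tok_pairs):
        if tok == 151644 and count_im_start < 2:
            template_end = i
            count_im_start += 1

    # also drop the "user" (872) and newline (198) tokens that follow it
    if out.shape[1] > template_end + 3 and tok_pairs[template_end + 1][0] == 872 and tok_pairs[template_end + 2][0] == 198:
        template_end += 3

    out = out[:, template_end:]    # keep only the user-prompt portion
    mask = mask[:, template_end:]
    print(out.shape)               # torch.Size([1, 2, 16])
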
@@ -84,14 +117,48 @@ def load_sd(self, sd):
             return self.byt5_small.load_sd(sd)
         else:
             return super().load_sd(sd)
+class HunyuanImageRefinerTEModel(sd1_clip.SD1ClipModel):
+    def __init__(self, device="cpu", dtype=None, model_options={}):
+        super().__init__(device=device, dtype=dtype, name="qwen25_7b", clip_model=Qwen25_7BVLIModel, model_options=model_options)
+
+    def encode_token_weights(self, token_weight_pairs):
+        out, pooled, extra = super().encode_token_weights(token_weight_pairs)
+        tok_pairs = token_weight_pairs["qwen25_7b"][0]
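+        # find the first occurrence of token id 6171, which the refiner template
+        # uses as its end-of-prefix sentinel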
+        for i, v in enumerate(tok_pairs):
+            elem = v[0]
+            if not torch.is_tensor(elem):
+                if isinstance(elem, numbers.Integral):
+                    if elem == 6171:
+                        template_end = i
+                        break
+
+        out = out[:, template_end - 1:]
+
+        extra["attention_mask"] = extra["attention_mask"][:, template_end - 1:]
+        if extra["attention_mask"].sum() == torch.numel(extra["attention_mask"]):
+            extra.pop("attention_mask")  # attention mask is useless if no masked elements
+
+        return out, pooled, extra
+
+
+def te(byt5=True, dtype_llama=None, llama_scaled_fp8=None, refiner=False):
+    class HunyuanImageTEModel_(HunyuanImageTEModel):
 
-def te(byt5=True, dtype_llama=None, llama_scaled_fp8=None):
-    class QwenImageTEModel_(HunyuanImageTEModel):
         def __init__(self, device="cpu", dtype=None, model_options={}):
             if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
                 model_options = model_options.copy()
                 model_options["qwen_scaled_fp8"] = llama_scaled_fp8
             if dtype_llama is not None:
                 dtype = dtype_llama
             super().__init__(byt5=byt5, device=device, dtype=dtype, model_options=model_options)
-    return QwenImageTEModel_
+    class HunyuanImageTEModel_refiner(HunyuanImageRefinerTEModel):
+        def __init__(self, device="cpu", dtype=None, model_options={}):
+            if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
+                model_options = model_options.copy()
+                model_options["qwen_scaled_fp8"] = llama_scaled_fp8
+            if dtype_llama is not None:
+                dtype = dtype_llama
+            assert refiner, "refiner must be True"
+            assert not byt5, "byt5 must be False"
+            super().__init__(device=device, dtype=dtype, model_options=model_options)
+    return HunyuanImageTEModel_refiner if refiner else HunyuanImageTEModel_
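
Usage sketch (hypothetical call sites; the te() factory returns a class, not an instance):

    RefinerTE = te(byt5=False, refiner=True)
    refiner_te = RefinerTE(device="cpu")   # HunyuanImageTEModel_refiner

    BaseTE = te(byt5=True)
    base_te = BaseTE(device="cpu")         # HunyuanImageTEModel_ with the ByT5 branch enabled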