CLAMP_QUANTILE = 0.99

-# copied from hako-mikan/sd-webui-lora-block-weight/scripts/lora_block_weight.py
-BLOCKID26 = ["BASE", "IN00", "IN01", "IN02", "IN03", "IN04", "IN05", "IN06", "IN07", "IN08", "IN09", "IN10", "IN11", "M00", "OUT00", "OUT01", "OUT02", "OUT03", "OUT04", "OUT05", "OUT06", "OUT07", "OUT08", "OUT09", "OUT10", "OUT11"]
-BLOCKID17 = ["BASE", "IN01", "IN02", "IN04", "IN05", "IN07", "IN08", "M00", "OUT03", "OUT04", "OUT05", "OUT06", "OUT07", "OUT08", "OUT09", "OUT10", "OUT11"]
-BLOCKID12 = ["BASE", "IN04", "IN05", "IN07", "IN08", "M00", "OUT00", "OUT01", "OUT02", "OUT03", "OUT04", "OUT05"]
-BLOCKID20 = ["BASE", "IN00", "IN01", "IN02", "IN03", "IN04", "IN05", "IN06", "IN07", "IN08", "M00", "OUT00", "OUT01", "OUT02", "OUT03", "OUT04", "OUT05", "OUT06", "OUT07", "OUT08"]
-BLOCKNUMS = [12, 17, 20, 26]
-BLOCKIDS = [BLOCKID12, BLOCKID17, BLOCKID20, BLOCKID26]
-
-BLOCKS = ["encoder",  # BASE
-          "diffusion_model_input_blocks_0_",  # IN00
-          "diffusion_model_input_blocks_1_",  # IN01
-          "diffusion_model_input_blocks_2_",  # IN02
-          "diffusion_model_input_blocks_3_",  # IN03
-          "diffusion_model_input_blocks_4_",  # IN04
-          "diffusion_model_input_blocks_5_",  # IN05
-          "diffusion_model_input_blocks_6_",  # IN06
-          "diffusion_model_input_blocks_7_",  # IN07
-          "diffusion_model_input_blocks_8_",  # IN08
-          "diffusion_model_input_blocks_9_",  # IN09
-          "diffusion_model_input_blocks_10_",  # IN10
-          "diffusion_model_input_blocks_11_",  # IN11
-          "diffusion_model_middle_block_",  # M00
-          "diffusion_model_output_blocks_0_",  # OUT00
-          "diffusion_model_output_blocks_1_",  # OUT01
-          "diffusion_model_output_blocks_2_",  # OUT02
-          "diffusion_model_output_blocks_3_",  # OUT03
-          "diffusion_model_output_blocks_4_",  # OUT04
-          "diffusion_model_output_blocks_5_",  # OUT05
-          "diffusion_model_output_blocks_6_",  # OUT06
-          "diffusion_model_output_blocks_7_",  # OUT07
-          "diffusion_model_output_blocks_8_",  # OUT08
-          "diffusion_model_output_blocks_9_",  # OUT09
-          "diffusion_model_output_blocks_10_",  # OUT10
-          "diffusion_model_output_blocks_11_",  # OUT11
-          "embedders",
-          "transformer_resblocks"]
-
-
-def convert_diffusers_name_to_compvis(key, is_sd2):
-    "copied from AUTOMATIC1111/stable-diffusion-webui/extensions-builtin/Lora/networks.py"
-
-    # put original globals here
-    re_digits = re.compile(r"\d+")
-    re_x_proj = re.compile(r"(.*)_([qkv]_proj)$")
-    re_compiled = {}
-
-    suffix_conversion = {
-        "attentions": {},
-        "resnets": {
-            "conv1": "in_layers_2",
-            "conv2": "out_layers_3",
-            "time_emb_proj": "emb_layers_1",
-            "conv_shortcut": "skip_connection",
-        },
-    }  # end of original globals
-
-    def match(match_list, regex_text):
-        regex = re_compiled.get(regex_text)
-        if regex is None:
-            regex = re.compile(regex_text)
-            re_compiled[regex_text] = regex
-
-        r = re.match(regex, key)
-        if not r:
-            return False
-
-        match_list.clear()
-        match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()])
-        return True
-
-    m = []
-
-    if match(m, r"lora_unet_conv_in(.*)"):
-        return f'diffusion_model_input_blocks_0_0{m[0]}'
-
-    if match(m, r"lora_unet_conv_out(.*)"):
-        return f'diffusion_model_out_2{m[0]}'
-
-    if match(m, r"lora_unet_time_embedding_linear_(\d+)(.*)"):
-        return f"diffusion_model_time_embed_{m[0] * 2 - 2}{m[1]}"
-
-    if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
-        suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
-        return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"
-
-    if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"):
-        suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2])
-        return f"diffusion_model_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}"
-
-    if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
-        suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
-        return f"diffusion_model_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"
-
-    if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"):
-        return f"diffusion_model_input_blocks_{3 + m[0] * 3}_0_op"
-
-    if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"):
-        return f"diffusion_model_output_blocks_{2 + m[0] * 3}_{2 if m[0] > 0 else 1}_conv"
-
-    if match(m, r"lora_te_text_model_encoder_layers_(\d+)_(.+)"):
-        if is_sd2:
-            if 'mlp_fc1' in m[1]:
-                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
-            elif 'mlp_fc2' in m[1]:
-                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
-            else:
-                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"
-
-        return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"
-
-    if match(m, r"lora_te2_text_model_encoder_layers_(\d+)_(.+)"):
-        if 'mlp_fc1' in m[1]:
-            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
-        elif 'mlp_fc2' in m[1]:
-            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
-        else:
-            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"
-
-    return key
+ACCEPTABLE = [12, 17, 20, 26]
+SDXL_LAYER_NUM = [12, 20]
+
+LAYER12 = {
+    "BASE": True,
+    "IN00": False, "IN01": False, "IN02": False, "IN03": False, "IN04": True, "IN05": True,
+    "IN06": False, "IN07": True, "IN08": True, "IN09": False, "IN10": False, "IN11": False,
+    "MID00": True,
+    "OUT00": True, "OUT01": True, "OUT02": True, "OUT03": True, "OUT04": True, "OUT05": True,
+    "OUT06": False, "OUT07": False, "OUT08": False, "OUT09": False, "OUT10": False, "OUT11": False,
+}
+
+LAYER17 = {
+    "BASE": True,
+    "IN00": False, "IN01": True, "IN02": True, "IN03": False, "IN04": True, "IN05": True,
+    "IN06": False, "IN07": True, "IN08": True, "IN09": False, "IN10": False, "IN11": False,
+    "MID00": True,
+    "OUT00": False, "OUT01": False, "OUT02": False, "OUT03": True, "OUT04": True, "OUT05": True,
+    "OUT06": True, "OUT07": True, "OUT08": True, "OUT09": True, "OUT10": True, "OUT11": True,
+}
+
+LAYER20 = {
+    "BASE": True,
+    "IN00": True, "IN01": True, "IN02": True, "IN03": True, "IN04": True, "IN05": True,
+    "IN06": True, "IN07": True, "IN08": True, "IN09": False, "IN10": False, "IN11": False,
+    "MID00": True,
+    "OUT00": True, "OUT01": True, "OUT02": True, "OUT03": True, "OUT04": True, "OUT05": True,
+    "OUT06": True, "OUT07": True, "OUT08": True, "OUT09": False, "OUT10": False, "OUT11": False,
+}
+
+layer12 = LAYER12.values()
+layer17 = LAYER17.values()
+layer20 = LAYER20.values()
+layer26 = [True] * 26
+assert len([v for v in layer12 if v]) == 12
+assert len([v for v in layer17 if v]) == 17
+assert len([v for v in layer20 if v]) == 20
+assert len([v for v in layer26 if v]) == 26
+
+RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
+
+
+def get_block_index(lora_name: str, is_sdxl: bool = False) -> int:
+    block_idx = -1  # invalid lora name
+    if not is_sdxl:
+        m = RE_UPDOWN.search(lora_name)
+        if m:
+            g = m.groups()
+            i = int(g[1])
+            j = int(g[3])
+            if g[2] == "resnets":
+                idx = 3 * i + j
+            elif g[2] == "attentions":
+                idx = 3 * i + j
+            elif g[2] == "upsamplers" or g[2] == "downsamplers":
+                idx = 3 * i + 2
+
+            if g[0] == "down":
+                block_idx = 1 + idx  # no LoRA module corresponds to index 0
+            elif g[0] == "up":
+                block_idx = LoRANetwork.NUM_OF_BLOCKS + 1 + idx
+        elif "mid_block_" in lora_name:
+            block_idx = LoRANetwork.NUM_OF_BLOCKS  # idx=12
+    else:
+        # copied from sdxl_train
+        if lora_name.startswith("lora_unet_"):
+            name = lora_name[len("lora_unet_"):]
+            if name.startswith("time_embed_") or name.startswith("label_emb_"):  # 0, no LoRA
+                block_idx = 0
+            elif name.startswith("input_blocks_"):  # 1-9
+                block_idx = 1 + int(name.split("_")[2])
+            elif name.startswith("middle_block_"):  # 10-12
+                block_idx = 10 + int(name.split("_")[2])
+            elif name.startswith("output_blocks_"):  # 13-21
+                block_idx = 13 + int(name.split("_")[2])
+            elif name.startswith("out_"):  # 22, out, no LoRA
+                block_idx = 22
+
+    return block_idx


def load_state_dict(file_name, dtype):
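
As a quick sanity check on the new mapping, here is a reviewer's sketch exercising get_block_index directly (not part of the commit; the sample key names are hypothetical, and LoRANetwork.NUM_OF_BLOCKS is assumed to be 12, as the idx=12 comment implies):

# down_blocks_0 / attentions_0: idx = 3*0 + 0, and "down" adds 1
assert get_block_index("lora_unet_down_blocks_0_attentions_0_proj_in") == 1
# the mid block maps straight to NUM_OF_BLOCKS
assert get_block_index("lora_unet_mid_block_attentions_0_proj_in") == 12
# up_blocks_1 / attentions_2: 12 + 1 + (3*1 + 2) = 18
assert get_block_index("lora_unet_up_blocks_1_attentions_2_proj_in") == 18
# the SDXL branch indexes input_blocks_N directly: 1 + 4 = 5
assert get_block_index("lora_unet_input_blocks_4_1_proj_in", is_sdxl=True) == 5
# unmatched names keep the -1 "invalid" marker
assert get_block_index("lora_te_text_model_encoder_layers_0_mlp_fc1") == -1
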
@@ -165,10 +125,10 @@ def save_to_file(file_name, state_dict, dtype, metadata):
    torch.save(state_dict, file_name)


-def merge_lora_models(is_sd2, models, ratios, lbws, new_rank, new_conv_rank, device, merge_dtype):
+def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, merge_dtype):
    logger.info(f"new rank: {new_rank}, new conv rank: {new_conv_rank}")
    merged_sd = {}
-    v2 = None
+    v2 = None  # this means LoRA metadata v2, not SD2
    base_model = None

    if lbws:
@@ -179,12 +139,18 @@ def merge_lora_models(is_sd2, models, ratios, lbws, new_rank, new_conv_rank, device, merge_dtype):
            raise ValueError(f"format of lbws must be json / 層別適用率はJSON形式で書いてください")
        assert all(isinstance(lbw, list) for lbw in lbws), f"lbws must be lists / 層別適用率はリストにしてください"
        assert len(set(len(lbw) for lbw in lbws)) == 1, "all lbws should have the same length / 層別適用率は同じ長さにしてください"
-        assert all(len(lbw) in BLOCKNUMS for lbw in lbws), f"length of lbw must be in {BLOCKNUMS} / 層別適用率の長さは{BLOCKNUMS}のいずれかにしてください"
+        assert all(len(lbw) in ACCEPTABLE for lbw in lbws), f"length of lbw must be in {ACCEPTABLE} / 層別適用率の長さは{ACCEPTABLE}のいずれかにしてください"
        assert all(all(isinstance(weight, (int, float)) for weight in lbw) for lbw in lbws), f"values of lbw must be numbers / 層別適用率の値はすべて数値にしてください"

-        BLOCKID = BLOCKIDS[BLOCKNUMS.index(len(lbws[0]))]
-        conditions = [blockid in BLOCKID for blockid in BLOCKID26]
-        BLOCKS_ = [block for block, condition in zip(BLOCKS, conditions) if condition]
+        layer_num = len(lbws[0])
+        FLAGS = {
+            "12": layer12,
+            "17": layer17,
+            "20": layer20,
+            "26": layer26,
+        }[str(layer_num)]
+        is_sdxl = layer_num in SDXL_LAYER_NUM
+        TARGET = [i for i, flag in enumerate(FLAGS) if flag]

    for model, ratio, lbw in itertools.zip_longest(models, ratios, lbws):
        logger.info(f"loading: {model}")
@@ -196,22 +162,21 @@ def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, merge_dtype):
        if base_model is None:
            base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None)

+        full_lbw = [1] * 26
+        for weight, index in zip(lbw, TARGET):
+            full_lbw[index] = weight
+
        # merge
        logger.info(f"merging...")
        for key in tqdm(list(lora_sd.keys())):
            if "lora_down" not in key:
                continue

            if lbw:
-                # convert the key from diffusers format such as lora_unet_down_blocks_0_
-                # to compvis format such as diffusion_model_input_blocks_0_
-                compvis_key = convert_diffusers_name_to_compvis(key, is_sd2)
-
-                block_in_key = [block in compvis_key for block in BLOCKS_]
-                is_lbw_target = any(block_in_key)
+                index = 0 if "encoder" in key else get_block_index(key, is_sdxl)
+                is_lbw_target = index in TARGET
                if is_lbw_target:
-                    index = [i for i, in_key in enumerate(block_in_key) if in_key][0]
-                    lbw_weight = lbw[index]
+                    lbw_weight = full_lbw[index]

            lora_module_name = key[: key.rfind(".lora_down")]

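
For review, a minimal sketch of what the expansion above does for a 17-entry lbw (the 0.5 weights are made up; LAYER17, TARGET, and the loop mirror the diff):

lbw = [0.5] * 17
TARGET = [i for i, flag in enumerate(LAYER17.values()) if flag]
full_lbw = [1] * 26
for weight, index in zip(lbw, TARGET):
    full_lbw[index] = weight
# the 17 enabled slots (BASE, IN01, IN02, IN04, IN05, IN07, IN08, MID00,
# OUT03..OUT11) now hold 0.5; the nine disabled slots keep the neutral 1
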
@@ -344,7 +309,7 @@ def str_to_dtype(p):

    new_conv_rank = args.new_conv_rank if args.new_conv_rank is not None else args.new_rank
    state_dict, metadata, v2, base_model = merge_lora_models(
-        args.sd2, args.models, args.ratios, args.lbws, args.new_rank, new_conv_rank, args.device, merge_dtype
+        args.models, args.ratios, args.lbws, args.new_rank, new_conv_rank, args.device, merge_dtype
    )

    logger.info(f"calculating hashes and creating metadata...")
@@ -390,9 +355,6 @@ def setup_parser() -> argparse.ArgumentParser:
    parser.add_argument(
        "--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors"
    )
-    parser.add_argument(
-        "--sd2", action="store_true", help="set if LoRA models are for SD2 / マージするLoRAモデルがSD2用なら指定します"
-    )
    parser.add_argument(
        "--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors"
    )
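
Taken together, the --sd2 switch disappears and block weighting is driven entirely by the length of each --lbws entry. A hedged invocation sketch (file names are made up, the script path assumes this is networks/svd_merge_lora.py, and --lbws is assumed to take one JSON list per model, mirroring --ratios):

python networks/svd_merge_lora.py \
    --save_to merged.safetensors \
    --models style.safetensors detail.safetensors \
    --ratios 1.0 0.8 \
    --lbws "[1,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1]" "[1,1,1,1,1,1,1,1,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5]" \
    --new_rank 32

Both lists have 17 entries, so LAYER17 is selected and is_sdxl stays False; a 12- or 20-entry list (SDXL_LAYER_NUM) would instead route keys through the SDXL branch of get_block_index.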