
Commit 41357a4

solve license incompatibility
1 parent 0a639d5 commit 41357a4

File tree

1 file changed: +99 −137 lines changed

networks/svd_merge_lora.py

@@ -17,125 +17,85 @@
 
 CLAMP_QUANTILE = 0.99
 
-# copied from hako-mikan/sd-webui-lora-block-weight/scripts/lora_block_weight.py
-BLOCKID26=["BASE","IN00","IN01","IN02","IN03","IN04","IN05","IN06","IN07","IN08","IN09","IN10","IN11","M00","OUT00","OUT01","OUT02","OUT03","OUT04","OUT05","OUT06","OUT07","OUT08","OUT09","OUT10","OUT11"]
-BLOCKID17=["BASE","IN01","IN02","IN04","IN05","IN07","IN08","M00","OUT03","OUT04","OUT05","OUT06","OUT07","OUT08","OUT09","OUT10","OUT11"]
-BLOCKID12=["BASE","IN04","IN05","IN07","IN08","M00","OUT00","OUT01","OUT02","OUT03","OUT04","OUT05"]
-BLOCKID20=["BASE","IN00","IN01","IN02","IN03","IN04","IN05","IN06","IN07","IN08","M00","OUT00","OUT01","OUT02","OUT03","OUT04","OUT05","OUT06","OUT07","OUT08"]
-BLOCKNUMS = [12,17,20,26]
-BLOCKIDS=[BLOCKID12,BLOCKID17,BLOCKID20,BLOCKID26]
-
-BLOCKS=["encoder", # BASE
-        "diffusion_model_input_blocks_0_", # IN00
-        "diffusion_model_input_blocks_1_", # IN01
-        "diffusion_model_input_blocks_2_", # IN02
-        "diffusion_model_input_blocks_3_", # IN03
-        "diffusion_model_input_blocks_4_", # IN04
-        "diffusion_model_input_blocks_5_", # IN05
-        "diffusion_model_input_blocks_6_", # IN06
-        "diffusion_model_input_blocks_7_", # IN07
-        "diffusion_model_input_blocks_8_", # IN08
-        "diffusion_model_input_blocks_9_", # IN09
-        "diffusion_model_input_blocks_10_", # IN10
-        "diffusion_model_input_blocks_11_", # IN11
-        "diffusion_model_middle_block_", # M00
-        "diffusion_model_output_blocks_0_", # OUT00
-        "diffusion_model_output_blocks_1_", # OUT01
-        "diffusion_model_output_blocks_2_", # OUT02
-        "diffusion_model_output_blocks_3_", # OUT03
-        "diffusion_model_output_blocks_4_", # OUT04
-        "diffusion_model_output_blocks_5_", # OUT05
-        "diffusion_model_output_blocks_6_", # OUT06
-        "diffusion_model_output_blocks_7_", # OUT07
-        "diffusion_model_output_blocks_8_", # OUT08
-        "diffusion_model_output_blocks_9_", # OUT09
-        "diffusion_model_output_blocks_10_", # OUT10
-        "diffusion_model_output_blocks_11_", # OUT11
-        "embedders",
-        "transformer_resblocks"]
-
-
-def convert_diffusers_name_to_compvis(key, is_sd2):
-    "copied from AUTOMATIC1111/stable-diffusion-webui/extensions-builtin/Lora/networks.py"
-
-    # put original globals here
-    re_digits = re.compile(r"\d+")
-    re_x_proj = re.compile(r"(.*)_([qkv]_proj)$")
-    re_compiled = {}
-
-    suffix_conversion = {
-        "attentions": {},
-        "resnets": {
-            "conv1": "in_layers_2",
-            "conv2": "out_layers_3",
-            "time_emb_proj": "emb_layers_1",
-            "conv_shortcut": "skip_connection",
-        }
-    }  # end of original globals
-
-    def match(match_list, regex_text):
-        regex = re_compiled.get(regex_text)
-        if regex is None:
-            regex = re.compile(regex_text)
-            re_compiled[regex_text] = regex
-
-        r = re.match(regex, key)
-        if not r:
-            return False
-
-        match_list.clear()
-        match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()])
-        return True
-
-    m = []
-
-    if match(m, r"lora_unet_conv_in(.*)"):
-        return f'diffusion_model_input_blocks_0_0{m[0]}'
-
-    if match(m, r"lora_unet_conv_out(.*)"):
-        return f'diffusion_model_out_2{m[0]}'
-
-    if match(m, r"lora_unet_time_embedding_linear_(\d+)(.*)"):
-        return f"diffusion_model_time_embed_{m[0] * 2 - 2}{m[1]}"
-
-    if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
-        suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
-        return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"
-
-    if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"):
-        suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2])
-        return f"diffusion_model_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}"
-
-    if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
-        suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
-        return f"diffusion_model_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"
-
-    if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"):
-        return f"diffusion_model_input_blocks_{3 + m[0] * 3}_0_op"
-
-    if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"):
-        return f"diffusion_model_output_blocks_{2 + m[0] * 3}_{2 if m[0]>0 else 1}_conv"
-
-    if match(m, r"lora_te_text_model_encoder_layers_(\d+)_(.+)"):
-        if is_sd2:
-            if 'mlp_fc1' in m[1]:
-                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
-            elif 'mlp_fc2' in m[1]:
-                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
-            else:
-                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"
-
-        return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"
-
-    if match(m, r"lora_te2_text_model_encoder_layers_(\d+)_(.+)"):
-        if 'mlp_fc1' in m[1]:
-            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
-        elif 'mlp_fc2' in m[1]:
-            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
-        else:
-            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"
-
-    return key
+ACCEPTABLE = [12, 17, 20, 26]
+SDXL_LAYER_NUM = [12, 20]
+
+LAYER12 = {
+    "BASE": True,
+    "IN00": False, "IN01": False, "IN02": False, "IN03": False, "IN04": True, "IN05": True,
+    "IN06": False, "IN07": True, "IN08": True, "IN09": False, "IN10": False, "IN11": False,
+    "MID00": True,
+    "OUT00": True, "OUT01": True, "OUT02": True, "OUT03": True, "OUT04": True, "OUT05": True,
+    "OUT06": False, "OUT07": False, "OUT08": False, "OUT09": False, "OUT10": False, "OUT11": False
+}
+
+LAYER17 = {
+    "BASE": True,
+    "IN00": False, "IN01": True, "IN02": True, "IN03": False, "IN04": True, "IN05": True,
+    "IN06": False, "IN07": True, "IN08": True, "IN09": False, "IN10": False, "IN11": False,
+    "MID00": True,
+    "OUT00": False, "OUT01": False, "OUT02": False, "OUT03": True, "OUT04": True, "OUT05": True,
+    "OUT06": True, "OUT07": True, "OUT08": True, "OUT09": True, "OUT10": True, "OUT11": True,
+}
+
+LAYER20 = {
+    "BASE": True,
+    "IN00": True, "IN01": True, "IN02": True, "IN03": True, "IN04": True, "IN05": True,
+    "IN06": True, "IN07": True, "IN08": True, "IN09": False, "IN10": False, "IN11": False,
+    "MID00": True,
+    "OUT00": True, "OUT01": True, "OUT02": True, "OUT03": True, "OUT04": True, "OUT05": True,
+    "OUT06": True, "OUT07": True, "OUT08": True, "OUT09": False, "OUT10": False, "OUT11": False,
+}
+
+layer12 = LAYER12.values()
+layer17 = LAYER17.values()
+layer20 = LAYER20.values()
+layer26 = [True] * 26
+assert len([v for v in layer12 if v]) == 12
+assert len([v for v in layer17 if v]) == 17
+assert len([v for v in layer20 if v]) == 20
+assert len([v for v in layer26 if v]) == 26
+
+RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
+
+
+def get_block_index(lora_name: str, is_sdxl: bool = False) -> int:
+    block_idx = -1  # invalid lora name
+    if not is_sdxl:
+        m = RE_UPDOWN.search(lora_name)
+        if m:
+            g = m.groups()
+            i = int(g[1])
+            j = int(g[3])
+            if g[2] == "resnets":
+                idx = 3 * i + j
+            elif g[2] == "attentions":
+                idx = 3 * i + j
+            elif g[2] == "upsamplers" or g[2] == "downsamplers":
+                idx = 3 * i + 2
+
+            if g[0] == "down":
+                block_idx = 1 + idx  # no LoRA corresponds to 0
+            elif g[0] == "up":
+                block_idx = LoRANetwork.NUM_OF_BLOCKS + 1 + idx
+        elif "mid_block_" in lora_name:
+            block_idx = LoRANetwork.NUM_OF_BLOCKS  # idx=12
+    else:
+        # copied from sdxl_train
+        if lora_name.startswith("lora_unet_"):
+            name = lora_name[len("lora_unet_") :]
+            if name.startswith("time_embed_") or name.startswith("label_emb_"):  # No LoRA
+                block_idx = 0  # 0
+            elif name.startswith("input_blocks_"):  # 1-9
+                block_idx = 1 + int(name.split("_")[2])
+            elif name.startswith("middle_block_"):  # 10-12
+                block_idx = 10 + int(name.split("_")[2])
+            elif name.startswith("output_blocks_"):  # 13-21
+                block_idx = 13 + int(name.split("_")[2])
+            elif name.startswith("out_"):  # 22, out, no LoRA
+                block_idx = 22
+
+    return block_idx
 
 
 def load_state_dict(file_name, dtype):
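
The new get_block_index above replaces the removed conversion routine: rather than rewriting diffusers-style keys into compvis names and string-matching block prefixes, it computes a block index directly from the key. As a sanity sketch (not part of the commit), here is what it returns for a few representative module names, assuming LoRANetwork.NUM_OF_BLOCKS == 12 as defined elsewhere in sd-scripts; the key names are illustrative:

get_block_index("lora_unet_down_blocks_1_attentions_0_proj_in")      # 1 + (3*1 + 0) = 4
get_block_index("lora_unet_mid_block_attentions_0_proj_in")          # NUM_OF_BLOCKS = 12
get_block_index("lora_unet_up_blocks_2_resnets_1_conv1")             # 12 + 1 + (3*2 + 1) = 20
get_block_index("lora_unet_input_blocks_4_1_proj_in", is_sdxl=True)  # 1 + 4 = 5
get_block_index("lora_te_text_model_encoder_layers_0_mlp_fc1")       # -1 (the merge loop maps "encoder" keys to 0 itself)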
@@ -165,10 +125,10 @@ def save_to_file(file_name, state_dict, dtype, metadata):
         torch.save(state_dict, file_name)
 
 
-def merge_lora_models(is_sd2, models, ratios, lbws, new_rank, new_conv_rank, device, merge_dtype):
+def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, merge_dtype):
     logger.info(f"new rank: {new_rank}, new conv rank: {new_conv_rank}")
     merged_sd = {}
-    v2 = None
+    v2 = None  # this means LoRA metadata v2, not SD2
     base_model = None
 
     if lbws:
@@ -179,12 +139,18 @@ def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, dev
             raise ValueError(f"format of lbws must be JSON / 層別適用率はJSON形式で書いてください")
         assert all(isinstance(lbw, list) for lbw in lbws), f"lbws must be lists / 層別適用率はリストにしてください"
         assert len(set(len(lbw) for lbw in lbws)) == 1, "all lbws should have the same length / 層別適用率は同じ長さにしてください"
-        assert all(len(lbw) in BLOCKNUMS for lbw in lbws), f"length of lbw must be one of {BLOCKNUMS} / 層別適用率の長さは{BLOCKNUMS}のいずれかにしてください"
+        assert all(len(lbw) in ACCEPTABLE for lbw in lbws), f"length of lbw must be one of {ACCEPTABLE} / 層別適用率の長さは{ACCEPTABLE}のいずれかにしてください"
         assert all(all(isinstance(weight, (int, float)) for weight in lbw) for lbw in lbws), f"values of lbw must all be numbers / 層別適用率の値はすべて数値にしてください"
 
-        BLOCKID = BLOCKIDS[BLOCKNUMS.index(len(lbws[0]))]
-        conditions = [blockid in BLOCKID for blockid in BLOCKID26]
-        BLOCKS_ = [block for block, condition in zip(BLOCKS, conditions) if condition]
+        layer_num = len(lbws[0])
+        FLAGS = {
+            "12": layer12,
+            "17": layer17,
+            "20": layer20,
+            "26": layer26,
+        }[str(layer_num)]
+        is_sdxl = True if layer_num in SDXL_LAYER_NUM else False
+        TARGET = [i for i, flag in enumerate(FLAGS) if flag]
 
     for model, ratio, lbw in itertools.zip_longest(models, ratios, lbws):
         logger.info(f"loading: {model}")
@@ -196,22 +162,21 @@ def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, dev
         if base_model is None:
             base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None)
 
+        full_lbw = [1] * 26
+        for weight, index in zip(lbw, TARGET):
+            full_lbw[index] = weight
+
         # merge
         logger.info(f"merging...")
         for key in tqdm(list(lora_sd.keys())):
             if "lora_down" not in key:
                 continue
 
             if lbw:
-                # convert the key from diffusers format such as lora_unet_down_blocks_0_
-                # to compvis format such as diffusion_model_input_blocks_0_
-                compvis_key = convert_diffusers_name_to_compvis(key, is_sd2)
-
-                block_in_key = [block in compvis_key for block in BLOCKS_]
-                is_lbw_target = any(block_in_key)
+                index = 0 if "encoder" in key else get_block_index(key, is_sdxl)
+                is_lbw_target = index in TARGET
                 if is_lbw_target:
-                    index = [i for i, in_key in enumerate(block_in_key) if in_key][0]
-                    lbw_weight = lbw[index]
+                    lbw_weight = full_lbw[index]
 
             lora_module_name = key[: key.rfind(".lora_down")]
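
The full_lbw expansion added above scatters the short per-layer list into all 26 slots, leaving untargeted blocks at weight 1. A tiny worked example with hypothetical weights, reusing the 17-element TARGET derived earlier:

lbw = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6]  # hypothetical
full_lbw = [1] * 26
for weight, index in zip(lbw, TARGET):
    full_lbw[index] = weight
# full_lbw[0] == 0.0 (BASE), full_lbw[2] == 0.1 (IN01), full_lbw[13] == 0.7 (MID00);
# untargeted slots such as full_lbw[1] (IN00) keep the default weight 1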
@@ -344,7 +309,7 @@ def str_to_dtype(p):
 
     new_conv_rank = args.new_conv_rank if args.new_conv_rank is not None else args.new_rank
     state_dict, metadata, v2, base_model = merge_lora_models(
-        args.sd2, args.models, args.ratios, args.lbws, args.new_rank, new_conv_rank, args.device, merge_dtype
+        args.models, args.ratios, args.lbws, args.new_rank, new_conv_rank, args.device, merge_dtype
     )
 
     logger.info(f"calculating hashes and creating metadata...")
@@ -390,9 +355,6 @@ def setup_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors"
     )
-    parser.add_argument(
-        "--sd2", action="store_true", help="set if LoRA models are for SD2 / マージするLoRAモデルがSD2用なら指定します"
-    )
     parser.add_argument(
         "--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors"
    )
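
With the conversion code gone, the --sd2 flag is removed as well: the merge no longer needs to be told whether a LoRA targets SD2. A hypothetical invocation under the new parser (flag names taken from setup_parser and the merge call above; the lbw strings are illustrative 17-element JSON lists, assuming --lbws accepts one list per model as the per-model handling suggests):

python networks/svd_merge_lora.py \
  --models lora_a.safetensors lora_b.safetensors \
  --ratios 1.0 0.8 \
  --lbws "[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1]" "[0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1]" \
  --new_rank 32 \
  --save_to merged.safetensors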
