Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Lora Block Weight (LBW) to svd_merge_lora.py #1575

Merged
merged 4 commits into from
Sep 13, 2024

Conversation

terracottahaniwa
Copy link
Contributor

@terracottahaniwa terracottahaniwa commented Sep 7, 2024

I had added support that lora block weight to svd_merge_lora.py.

  • --sd2 and --lbws options are added.
  • --sd2 option is used to parse key names from diffusers format to compvis format. (removed)
  • --lbws is used to pass the block weights as follows;
python svd_merge_lora.py --save_precision bf16 --precision float \
--save_to output.safetensors --models input1.safetensors input2.safetensors \
--ratios 1 1 --lbws "[1,1,1,1,1,1,1,1,1,1,1,1]" "[1,1,1,1,1,1,1,1,1,1,1,1]" --new_rank 1

@kohya-ss
Copy link
Owner

kohya-ss commented Sep 9, 2024

Thank you for this! This is really nice feature.

However, stable-diffusion-webui is licensed under AGPL, and the license for hako-mikan's code is also unclear. So it seems we can't use the code as is in this repository, and we'll need to rewrite it into our own code.

In addition, currently we are determining whether it is a v2 model from metadata, but it seems safer to give it as an argument like in your code. However, I think --v2 is the best option name, along with other scripts.

I can make these changes myself, but it will take some time.

@JujoHotaru
Copy link

XL対応のLoRAやデータモデル開発を行っている十条と申します。
実はこのPR、SuperMergerのXL LoRAの階層マージがうまく動かず困っていた私がDiscordで相談し、実装していただいたものでした。対応検討いただけるとのことで期待しております。

@terracottahaniwa
Copy link
Contributor Author

Thanks for the reply.
I see, I will going to try to solve the license incompatibility issue.

@kohya-ss
Copy link
Owner

I think convert_diffusers_name_to_compvis is used to get the block index from the weight key name, so can the following function be used instead?

sd-scripts/networks/lora.py

Lines 719 to 755 in 65b8a06

def get_block_index(lora_name: str, is_sdxl: bool = False) -> int:
block_idx = -1 # invalid lora name
if not is_sdxl:
m = RE_UPDOWN.search(lora_name)
if m:
g = m.groups()
i = int(g[1])
j = int(g[3])
if g[2] == "resnets":
idx = 3 * i + j
elif g[2] == "attentions":
idx = 3 * i + j
elif g[2] == "upsamplers" or g[2] == "downsamplers":
idx = 3 * i + 2
if g[0] == "down":
block_idx = 1 + idx # 0に該当するLoRAは存在しない
elif g[0] == "up":
block_idx = LoRANetwork.NUM_OF_BLOCKS + 1 + idx
elif "mid_block_" in lora_name:
block_idx = LoRANetwork.NUM_OF_BLOCKS # idx=12
else:
# copy from sdxl_train
if lora_name.startswith("lora_unet_"):
name = lora_name[len("lora_unet_") :]
if name.startswith("time_embed_") or name.startswith("label_emb_"): # No LoRA
block_idx = 0 # 0
elif name.startswith("input_blocks_"): # 1-9
block_idx = 1 + int(name.split("_")[2])
elif name.startswith("middle_block_"): # 10-12
block_idx = 10 + int(name.split("_")[2])
elif name.startswith("output_blocks_"): # 13-21
block_idx = 13 + int(name.split("_")[2])
elif name.startswith("out_"): # 22, out, no LoRA
block_idx = 22
return block_idx

Or the following function may be able to use for SD1/2:

def convert_unet_state_dict_to_sd(v2, unet_state_dict):

@terracottahaniwa
Copy link
Contributor Author

Thanks to show me the pointer!
I'll going to try this.

@terracottahaniwa
Copy link
Contributor Author

I rewrote the code using the get_block_index() function. During the code review, I found that the code does not depend on whether the model is SD2 or Not. Therefore, I removed the --sd2 option again.

@kohya-ss
Copy link
Owner

Sorry, get_block_index returns an index for specifying the block-wise learning rate in sd-scripts, so the value is different from hakomikan's index (e.g., get_block_index returns a different index for each middle block). So some conversion is needed.

Please wait a moment while I prepare the conversion code.

@terracottahaniwa
Copy link
Contributor Author

terracottahaniwa commented Sep 11, 2024

I'm sorry. I was misunderstood about the function.

@terracottahaniwa
Copy link
Contributor Author

I also noticed that Hakomikan's layer selection criteria is to have attention blocks.

lora = safetensors.safe_open(r"sdxl.safetensors", framework="pt")
s = []
for key in lora.keys():
    if "lora_down" not in key:
        continue
    if "attn" in key:
        m = re.search(r"_(\d+)_(\d)_", key)
        if m:
            s.append(key[0:m.end()])
print(sorted(list(set(s))))
"""
[
    'lora_unet_input_blocks_4_1_', 
    'lora_unet_input_blocks_5_1_', 
    'lora_unet_input_blocks_7_1_', 
    'lora_unet_input_blocks_8_1_', 
    'lora_unet_output_blocks_0_1_', 
    'lora_unet_output_blocks_1_1_', 
    'lora_unet_output_blocks_2_1_', 
    'lora_unet_output_blocks_3_1_', 
    'lora_unet_output_blocks_4_1_', 
    'lora_unet_output_blocks_5_1_'
]
"""

So, maybe, BASE has 'encoder', MID has 'middle', and other target layers have 'attn'. Thus, if the key contains strings like 'lora_unet_input_blocks_4_1_', Is it okay to consider this key as the target?

alternative implementation
main...terracottahaniwa:sd-scripts:dev

@kohya-ss
Copy link
Owner

I think the reason why hakomikan checks only the attention block is to extract only one LoRA key per block, so I don't think we need to worry too much about it.

So, maybe, BASE has 'encoder', MID has 'middle', and other target layers have 'attn'. Thus, if the key contains strings like 'lora_unet_input_blocks_4_1_', Is it okay to consider this key as the target?

That's absolutely true.

SDXL LoRA keys are based on CompVis (SAI) so they are easy to convert, but SD LoRA keys are based on Diffusers so they are difficult to convert.

I believe the below script will give us the correct lbw block index, can you please confirm?

import re
from safetensors.torch import load_file

RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")


def get_lbw_block_index(lora_name: str, is_sdxl: bool = False) -> int:
    # lbw block index is 0-based, but 0 for text encoder, so we return 0 for text encoder
    if "text_model_encoder_" in lora_name:  # LoRA for text encoder
        return 0

    # lbw block index is 1-based for U-Net, and no "input_blocks.0" in CompVis SD, so "input_blocks.1" have index 2
    block_idx = -1  # invalid lora name
    if not is_sdxl:
        NUM_OF_BLOCKS = 12  # up/down blocks
        m = RE_UPDOWN.search(lora_name)
        if m:
            g = m.groups()
            up_down = g[0]
            i = int(g[1])
            j = int(g[3])
            if up_down == "down":
                if g[2] == "resnets" or g[2] == "attentions":
                    idx = 3 * i + j + 1
                elif g[2] == "downsamplers":
                    idx = 3 * (i + 1)
                else:
                    return block_idx  # invalid lora name
            elif up_down == "up":
                if g[2] == "resnets" or g[2] == "attentions":
                    idx = 3 * i + j
                elif g[2] == "upsamplers":
                    idx = 3 * i + 2
                else:
                    return block_idx  # invalid lora name

            if g[0] == "down":
                block_idx = 1 + idx  # 1-based index, down block index
            elif g[0] == "up":
                block_idx = 1 + NUM_OF_BLOCKS + 1 + idx  # 1-based index, num blocks, mid block, up block index

        elif "mid_block_" in lora_name:
            block_idx = 1 + NUM_OF_BLOCKS  # 1-based index, num blocks, mid block
    else:
        if lora_name.startswith("lora_unet_"):
            name = lora_name[len("lora_unet_") :]
            if name.startswith("time_embed_") or name.startswith("label_emb_"):  # 1, No LoRA in sd-scripts
                block_idx = 1
            elif name.startswith("input_blocks_"):  # 1-8 to 2-9
                block_idx = 1 + int(name.split("_")[2])
            elif name.startswith("middle_block_"):  # 10
                block_idx = 10
            elif name.startswith("output_blocks_"):  # 0-8 to 11-19
                block_idx = 11 + int(name.split("_")[2])
            elif name.startswith("out_"):  # 20, No LoRA in sd-scripts
                block_idx = 20

    return block_idx


if __name__ == "__main__":
    print("SD1 LoRA")
    sd1_lora_path ="path/to/sd1_lora.safetensors"
    sd1_lora_sd = load_file(sd1_lora_path)
    block_indices_shown = set()
    for key in sd1_lora_sd.keys():
        lora_name = key.split(".")[0]

        # It may be a good idea to add the conversion to ComVis key here and check it.

        lbw_block_idx = get_lbw_block_index(lora_name, is_sdxl=False)
        if lbw_block_idx not in block_indices_shown:
            block_indices_shown.add(lbw_block_idx)
            print(f"{lora_name}: {lbw_block_idx}")

    print("\nSDXL LoRA")
    sdxl_lora_path = "path/to/sdxl_lora.safetensors"
    sdxl_lora_sd = load_file(sdxl_lora_path)
    block_indices_shown = set()
    for key in sdxl_lora_sd.keys():
        lora_name = key.split(".")[0]
        lbw_block_idx = get_lbw_block_index(lora_name, is_sdxl=True)
        if lbw_block_idx not in block_indices_shown:
            block_indices_shown.add(lbw_block_idx)
            print(f"{lora_name}: {lbw_block_idx}")

The output should look like this:

SD1 LoRA
lora_unet_down_blocks_0_attentions_0_proj_in: 2
lora_unet_down_blocks_0_attentions_1_proj_in: 3
lora_unet_down_blocks_0_downsamplers_0_conv: 4
lora_unet_down_blocks_1_attentions_0_proj_in: 5
lora_unet_down_blocks_1_attentions_1_proj_in: 6
lora_unet_down_blocks_1_downsamplers_0_conv: 7
lora_unet_down_blocks_2_attentions_0_proj_in: 8
lora_unet_down_blocks_2_attentions_1_proj_in: 9
lora_unet_down_blocks_2_downsamplers_0_conv: 10
lora_unet_down_blocks_3_resnets_0_conv1: 11
lora_unet_down_blocks_3_resnets_1_conv1: 12
lora_unet_mid_block_attentions_0_proj_in: 13
lora_unet_up_blocks_0_resnets_0_conv1: 14
lora_unet_up_blocks_0_resnets_1_conv1: 15
lora_unet_up_blocks_0_resnets_2_conv1: 16
lora_unet_up_blocks_1_attentions_0_proj_in: 17
lora_unet_up_blocks_1_attentions_1_proj_in: 18
lora_unet_up_blocks_1_attentions_2_proj_in: 19
lora_unet_up_blocks_2_attentions_0_proj_in: 20
lora_unet_up_blocks_2_attentions_1_proj_in: 21
lora_unet_up_blocks_2_attentions_2_proj_in: 22
lora_unet_up_blocks_3_attentions_0_proj_in: 23
lora_unet_up_blocks_3_attentions_1_proj_in: 24
lora_unet_up_blocks_3_attentions_2_proj_in: 25

SDXL LoRA
lora_unet_input_blocks_1_0_emb_layers_1: 2
lora_unet_input_blocks_2_0_emb_layers_1: 3
lora_unet_input_blocks_3_0_op: 4
lora_unet_input_blocks_4_0_emb_layers_1: 5
lora_unet_input_blocks_5_0_emb_layers_1: 6
lora_unet_input_blocks_6_0_op: 7
lora_unet_input_blocks_7_0_emb_layers_1: 8
lora_unet_input_blocks_8_0_emb_layers_1: 9
lora_unet_middle_block_0_emb_layers_1: 10
lora_unet_output_blocks_0_0_emb_layers_1: 11
lora_unet_output_blocks_1_0_emb_layers_1: 12
lora_unet_output_blocks_2_0_emb_layers_1: 13
lora_unet_output_blocks_3_0_emb_layers_1: 14
lora_unet_output_blocks_4_0_emb_layers_1: 15
lora_unet_output_blocks_5_0_emb_layers_1: 16
lora_unet_output_blocks_6_0_emb_layers_1: 17
lora_unet_output_blocks_7_0_emb_layers_1: 18
lora_unet_output_blocks_8_0_emb_layers_1: 19

@JujoHotaru
Copy link

本件ご対応進めていただいておりますところですが、LoRA同士のマージのほか、XLモデルへのLoRAマージ時(sdxl_merge_lora.py)でもlbws指定ができるとデータモデル開発において大変はかどります。
よろしければそちらもご検討いただければ幸いです。

@terracottahaniwa
Copy link
Contributor Author

Thank you for correcting the function! It seems to work well.
I rewrote the code again using get_lbw_block_index().

About the sdxl_merge_lora.py, there's another pull request (#1580) currently in progress.
So maybe we should wait a bit for that PR to resolve as well.

@kohya-ss kohya-ss merged commit 734d2e5 into kohya-ss:main Sep 13, 2024
1 check failed
@kohya-ss
Copy link
Owner

sdxl_merge_lora.pyでもLBWに対応いたしました。動作をご確認いただければ幸いです。

@JujoHotaru
Copy link

@kohya-ss

いつもスクリプトにお世話になっております。
実装いただきました本機能、さっそくLoRA開発に活用させていただいているのですが、少々不可解な動作が発生しており、不具合なのか仕様なのかがわからず、ご報告させていただきます。

現在、素材となる4つのLoRAを、強度・LBWそれぞれ個別にマージして1個の完成形LoRAにまとめようとしています。
プロンプトではこんな形になっています。AAAA/BBBB/CCCCは強度0.2、DDDDは強度1での適用です。

<lora:AAAA:0.2:lbw=0,0,1,1,1,1,1,1,1,0,0,0>
<lora:BBBB:0.2:lbw=0,0,1,1,1,1,1,1,1,0,0,0>
<lora:CCCC:0.2:lbw=0,0,1,1,1,1,1,1,1,0,0,0>
<lora:DDDD:1:lbw=0,0,1,1,1,1,1,1,1,0,0,0>

これを今回の機能を用いてマージしようとしており、svd_merge_loraへ与える引数はこのようにしています。

--ratios 0.2 0.2 0.2 1 --lbws [0,0,0,0,0,0,1,0,1,1,1,1,1,1,0,0,0,0,0,0] [0,0,0,0,0,0,1,0,1,1,1,1,1,1,0,0,0,0,0,0] [0,0,0,0,0,0,1,0,1,1,1,1,1,1,0,0,0,0,0,0] [0,0,0,0,0,0,1,0,1,1,1,1,1,1,0,0,0,0,0,0] 

これを動作させると、LoRAはできあがるのですが、それを適用した結果が、上記プロンプト(LoRA4個指定)で生成した場合と大きく異なったものになってしまいました。
あくまで近似なので、若干の差異は生じる仕様と理解していますが、まったく見た目が違う結果だった(背景も構図も荒れ荒れになってしまった)ので、何か不具合的なことになってしまっているのではないかと思いました。
結果を調べてみると、どうもAAAA/BBBB/CCCCの強度を「0.2」ではなく「1」として生成した場合の結果によく似ていることから、ratiosの指定が効いていない?のではないかと推測したのですが、それならばとratiosはすべて1に変え、lbwのほうの1をすべて0.2にして実行してみても、結果は変わらないようでした。

続いて、SuperMergerを使って、「lora:AAAA:1:lbw=0,0,1,1,1,1,1,1,1,0,0,0」、つまり強度1のまま不要階層だけカットしたものを単体でLoRA化(「AAAA_v2」)し、BBBB/CCCC/DDDDも同様に処理し、その4つをlbws指定無し(「--ratios 0.2 0.2 0.2 1」だけ)で合成したLoRAを作ってみたところ、こちらはプロンプトでLoRA4つ指定して生成した場合にほぼ同等な結果を得られました。

このため、lbwsを使った場合の計算結果に何か私の意図しているものと異なる結果が生じる部分があるようなのですが、何かわかることなどはございますでしょうか?

@kohya-ss
Copy link
Owner

LBWの値を26個に変換するとき、LoRA Block Weightのコードでは初期値0で指定部のみ値を設定しているようです。

https://github.com/hako-mikan/sd-webui-lora-block-weight/blob/4d94d247a6e6d77c11c2d536a5138a54f4d9e500/scripts/lora_block_weight.py#L1236-L1241

svd_merge_lora.pyの現在の実装では初期値1のようです。

if lbw:
lbw_weights = [1] * 26
for index, value in zip(LBW_TARGET_IDX, lbw):
lbw_weights[index] = value
logger.info(f"lbw: {dict(zip(LAYER26.keys(), lbw_weights))}")

もしかしたらこの違いかもしれません。試しに26個の値を引数で指定して比較してみていただけますでしょうか。

@JujoHotaru
Copy link

JujoHotaru commented Sep 18, 2024

ご確認いただきありがとうございます。
書き忘れておりましたが対象はSDXLのLoRAで、実は上記コメントの「svd_merge_loraに与える引数」のところに記載しているように、すでに不足分は0で埋めて20個分のlbwとしたものを引数に与えておりました。(LoRA BlockWeightの動作に合わせるため)
SDXLなので20個でよいかと思うのですが、26個指定したほうがよいでしょうか?ひとまず試してみます。

@JujoHotaru
Copy link

lbw26個指定のLoRAを作ってみましたが、出力結果はlbw20個指定で出したものと同じでした。

@kohya-ss
Copy link
Owner

すみません、12個か20個で指定しないとSDXL扱いではなくなるので、block indexがおかしくなりますね。

SuperMergerの"same to Strength"オプションは指定されているでしょうか。こちらが指定されていると、数式が変わってくるようです。

SuperMergerのコードを見ているのですが、非常に難解でちょっと原因がわかりかねています……。

@JujoHotaru
Copy link

他にも旺盛な開発活動をされている中で調査ご対応いただきありがとうございます。
Same to Strengthですが、他LoRAの開発においては、オンのほうが良い結果になる場合とその逆の場合の両方あり、主観で使い分けておりました。

ただ、今回のケース、上述した方法でLBWだけ制限したLoRAを生成する場合、強度は1.0で、1.0だと当該オプションはオンでもオフでも生成されたLoRAの適用結果画像はピクセルパーフェクトで全く同一になるようでした。
報告したケースではSame to StrengthオフのLoRAを最終的なマージに使用しております。

@kohya-ss
Copy link
Owner

ありがとうございます。SuperMergerのコードを見ても、ratioが1.0の場合、"same to Strength"がONでもOFFでも、同じ結果になるようです。

となるとsvd_merge_lora.pyのratioの扱いには問題なく、LBWの重みの指定の方に問題がありそうです。

SuperMergerの以下のいずれかにprint文を入れて、値を確認していただくことは可能でしょうか。

https://github.com/hako-mikan/sd-webui-supermerger/blob/843ca282948dbd3fac1246fcb1b66544a371778b/scripts/mergers/pluslora.py#L483

https://github.com/hako-mikan/sd-webui-supermerger/blob/843ca282948dbd3fac1246fcb1b66544a371778b/scripts/mergers/pluslora.py#L546

以下のように追加していただければ、LoRAのキーごとの、block indexが確認できると思います。

            print(key, blockfromkey(key, LBLCOKS26,isv2)) # この行を追加
            ratio = ratios[blockfromkey(key, LBLCOKS26,isv2)]

@JujoHotaru
Copy link

かしこまりました。早速ログを取ってみました。(添付)

log.txt

「AAAA:2:0,0,1,1,1,1,1,1,1,0,0,0」を入力とした場合のデータになります。試しに強度は2.0にしてあります。

@JujoHotaru
Copy link

なお、もし現物データがあったほうがよいようでしたらご提供可能です。

@kohya-ss
Copy link
Owner

kohya-ss commented Sep 18, 2024

ありがとうございます。ログを拝見したところ、input_blocks 1~8が2~9、middleが13、output_blocks 0~8が14~22、というマッピングになっているようですね。このあたりがsvd_merge_lora.pyのLAYER20などの処理で正しく反映されているか、確認してみます。

@terracottahaniwa
Copy link
Contributor Author

terracottahaniwa commented Sep 19, 2024

Oops, maybe my mistake.
The lbw weight might multiply only up weight.

https://github.com/hako-mikan/sd-webui-lora-block-weight/blob/4d94d247a6e6d77c11c2d536a5138a54f4d9e500/scripts/lora_block_weight.py#L1123

            if lbw:
                index = get_lbw_block_index(key, is_sdxl)
                is_lbw_target = index in LBW_TARGET_IDX
                if is_lbw_target:
-                   scale *= lbw_weights[index]  # keyがlbwの対象であれば、lbwの重みを掛ける
+                   up_weight *= lbw_weights[index]  # lbwの対象であれば、up_weightにlbwの重みを掛ける

@terracottahaniwa
Copy link
Contributor Author

terracottahaniwa commented Sep 19, 2024

I apologize. I had incorrectly recalled the details of the research that was conducted when I made my tool.
In that tool's implementation, I had multiplied to only the up weights.😫

https://github.com/terracottahaniwa/apply-lora-block-weight/blob/main/apply_lora_block_weight.py#L80-L81

@kohya-ss
Copy link
Owner

(up @ down) * ratio(up * ratio) @ downは同じ値になるはずですので、そこは問題なさそうです。

import torch

# スカラーの場合
up = torch.tensor([[1, 2], [3, 4]])
down = torch.tensor([[5, 6], [7, 8]])
ratio = 2

w1 = (up @ down) * ratio
w2 = (up * ratio) @ down

print(torch.all(w1 == w2))  # True

@kohya-ss
Copy link
Owner

どうもLBWのSDXLのblock indexは値が飛ぶ特殊な番号になっているようです。LBWと一致するように修正してみましたので、お試しいただけますでしょうか。

@JujoHotaru
Copy link

JujoHotaru commented Sep 19, 2024

早速最新コミットで試してみました。
結果、4つのLoRAを指定していた場合とほぼ同等の出力結果が得られる合成LoRAになりました!無事に修正できているようです。
調査・ご対応ありがとうございました。効率上昇で開発ペースを上げられそうです。
(参考までに、開発していたLoRAはこちらでした。https://x.com/JujoHotaru/status/1836753164449817065

@JujoHotaru
Copy link

ちなみに、sdxl_merge_loraでモデルにLoRAをマージする場合についても同様の問題はございましたでしょうか?

@kohya-ss
Copy link
Owner

無事に動いたようで幸いです。

ちなみに、sdxl_merge_loraでモデルにLoRAをマージする場合についても同様の問題はございましたでしょうか?

block indexを取得する処理が共通ですので、同じ問題があったはずですが、今回の修正で同時に修正されたかと思います。

@JujoHotaru
Copy link

承知しました。引き続き活用させていただきます。

@JujoHotaru
Copy link

追伸ですが、今回修正版でマージしたLoRAは、私の最初の報告の最後の手順(あらかじめSuperMergerで階層だけ制限しておいたLoRAを強度だけ指定でマージ)で作ったLoRAとハッシュが完全一致しておりました。
理論的には同じ計算手順ですので順当なのだと思いますが、一応ご報告させていただきます。

@terracottahaniwa
Copy link
Contributor Author

Simply, I need to relearning! 😂
image

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants