
Support inference with OFT networks #13692

Merged: 26 commits into AUTOMATIC1111:dev, Nov 19, 2023

Conversation

@v0xie (Contributor) commented Oct 19, 2023

Description

This PR adds support for inference with OFT networks trained with kohya-ss sd-scripts. The implementation is based on kohya's implementation here: https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py

This is a draft PR because of these major issues:

  • The network continues to affect all generations after loading, even when it is supposed to be unloaded
  • It noticeably slows down inference

The current implementation replaces sd_module's forward method with a custom forward method, which I believe is the cause of both the network continuing to affect future generations and the slowdown in inference.

Any suggestions on how to fix these issues are greatly appreciated!

Related links:
OFT project page: https://oft.wyliu.com
OFT official repository: https://github.com/Zeju1997/oft

Screenshots/videos:

Checklist:

@v0xie (Contributor, Author) commented Oct 20, 2023

Here's a .safetensors OFT network I trained on part of the Dreambooth dataset: https://huggingface.co/v0xie/sdxl-oft_monster_toy/resolve/main/monster_toy_oft.safetensors
It's a bit under-baked but it should do fine for testing purposes.

[Image: 25734-2-lora_monster_toy_oft_1 25 photo of a monster_toy toy on the beach]

SDXL OFT network trained with kohya-ss sd-scripts
Unique Token / Class Token: monster_toy toy
Linear/Alpha: 8/8
Conv/Alpha: 4/4
Trained for 16 epochs / 400 total steps

@v0xie (Contributor, Author) commented Oct 21, 2023

The slowdown during inference is fixed now; anecdotally, it's at least as fast as using any other network!

Instead of replacing the forward method, we now merge the network's weights into the weights of the original module (the merge_weight function). I'm still figuring out how to undo this merge, or whether replacing the weights is even necessary.
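As a minimal sketch of the merge idea (hypothetical names and shapes, not the PR's exact code), assuming a block-diagonal rotation R applied along the output dimension:

import torch

def merge_oft_weight(module: torch.nn.Module, R: torch.Tensor) -> torch.Tensor:
    # R: (num_blocks, block_size, block_size); module.weight: (out_features, ...)
    # with out_features == num_blocks * block_size
    original = module.weight.data.clone()  # keep a copy so unload can restore it
    num_blocks, block_size, _ = R.shape
    W = module.weight.data
    # view the output dimension as blocks and rotate each block: W' = R @ W
    W_blocks = W.reshape(num_blocks, block_size, -1)
    module.weight.data = torch.matmul(R.to(W.dtype), W_blocks).reshape(W.shape)
    return original

Restoring the saved tensor with module.weight.data.copy_(original) would then undo the merge when the network is unloaded.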

@v0xie (Contributor, Author) commented Oct 22, 2023

The remaining issue of weights continuing to stay loaded is fixed, and this PR is ready to merge.

Also fixed is another bug where the multiplier was applied a second time in finalize_updown, which unpredictably amplified the effect of the U-Net / Text Encoder multipliers. Kohya's implementation applies the multiplier in the get_weight function, so I've kept the same behavior in this implementation.
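To illustrate where the multiplier now enters, a hedged sketch (hypothetical signature; this mirrors the $(I + \gamma(R-I))W$ form discussed later in this thread):

import torch

def get_weight(R: torch.Tensor, W: torch.Tensor, multiplier: float) -> torch.Tensor:
    # fold the multiplier in exactly once, here;
    # finalize_updown must not scale the result again
    I = torch.eye(R.shape[-1], device=R.device, dtype=R.dtype)
    R_scaled = I + multiplier * (R - I)  # interpolate between identity and R
    return torch.matmul(R_scaled, W)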

@v0xie marked this pull request as ready for review October 22, 2023 16:53
@KohakuBlueleaf (Collaborator) commented:

@v0xie The method kohya implemented is too inefficient.
I will try to make some changes based on your PR
(you can check my implementation in the LyCORIS dev branch)

Commits:
  • … diag_oft has MultiheadAttn which kohya's doesn't?, attempt create new module based off network_lora.py, errors about tensor dim mismatch
  • Support LyCORIS diag-oft OFT implementation (minus MultiheadAttention layer), maintains support for kohya-ss OFT
@v0xie (Contributor, Author) commented Nov 4, 2023

@KohakuBlueleaf I just added support for the diag-oft module.

I don't know how to handle the Linear MultiheadAttention layers yet, so those are skipped until I understand better how they work. The implementation is a bit complicated since I'm also maintaining support for the kohya OFT module, so maybe it would be better to have two separate files?

@KohakuBlueleaf (Collaborator) commented:

@KohakuBlueleaf I just added support for the diag-oft module.

I don't know how to handle the Linear MultiheadAttention layers yet, so those are skipped until I understand better how they work. The implementation is a bit complicated since I'm also maintaining support for the kohya OFT module, so maybe it would be better to have two separate files?

Kohya's OFT is mathematically identical to my implementation (but inefficient).
You just need to make sure the key-name handling can deal with both formats;
for the math part, just use my implementation. A sketch of the format detection is below.
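For the key handling, a small sketch (assuming the two formats can be told apart by the oft_blocks / oft_diag key suffixes mentioned in this thread):

def detect_oft_format(state_dict: dict) -> str | None:
    # kohya-ss sd-scripts stores the skew parameter under "oft_blocks",
    # while LyCORIS diag-oft stores the rotation itself under "oft_diag"
    if any("oft_blocks" in key for key in state_dict):
        return "kohya"
    if any("oft_diag" in key for key in state_dict):
        return "lycoris"
    return None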

@v0xie marked this pull request as draft November 4, 2023 21:52
@KohakuBlueleaf (Collaborator) commented Nov 11, 2023

@v0xie I have done some more investigation on this.
It looks like what kohya did is actually "Re-scaled Constrained OFT", while what I have done is plain "OFT".
In their Re-scaled COFT implementation, the stored weight is actually "Q", and the weight in my implementation is actually "R".

I will add constrained OFT in the future, and I will follow Kohya's naming if I use Q here.
And sorry, you will still need two different rebuild-weight methods (or an if-else).

Basically the idea is:

if using_lycoris:
    # use (weight + I) directly as R
else:
    # use kohya's implementation to rebuild R first
# then rebuild the weight from R using my implementation
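Concretely, the two rebuild paths might look like this. This is a sketch under the assumptions stated in this thread: LyCORIS stores R directly (assumed here as R minus the identity), while kohya stores a parameter from which a skew-symmetric Q is built and R is recovered via the Cayley transform.

import torch

def rebuild_R(blocks: torch.Tensor, using_lycoris: bool) -> torch.Tensor:
    # blocks: (num_blocks, block_size, block_size)
    I = torch.eye(blocks.shape[-1], device=blocks.device, dtype=blocks.dtype)
    if using_lycoris:
        # the stored tensor is already R (assumed stored as R - I here)
        return blocks + I
    # kohya format: form a skew-symmetric Q, then R = (I + Q)(I - Q)^-1
    Q = blocks - blocks.transpose(-1, -2)
    return torch.matmul(I + Q, torch.linalg.inv(I - Q))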

And a warning here: Kohya's implementation is wrong, since their training doesn't compute "(RW)X" but "RX".
I will open an issue on sd-scripts.

@KohakuBlueleaf (Collaborator) commented:

@v0xie I changed my implementation: it will now only produce "oft_diag" as "R" (or rather, DR), so you can just use the logic I mentioned above.

@v0xie (Contributor, Author) commented Nov 12, 2023

@v0xie I changed my implementation: it will now only produce "oft_diag" as "R" (or rather, DR), so you can just use the logic I mentioned above.

Thank you for this, I'll take a look at reworking the implementation.

@FurkanGozukara commented:

Did you get better results (quality, not speed) than proper Dreambooth training?

I don't believe it, but I would like to get your opinion.

@KohakuBlueleaf (Collaborator) commented Nov 13, 2023

Did you get better results (quality, not speed) than proper Dreambooth training?

I don't believe it, but I would like to get your opinion.

It depends heavily on how you define "better", and heavily on your configs too.

It is quite like "high-rank LoKr" vs "low-rank LoKr" vs "Dreambooth": low-rank LoKr / large-block OFT can be very close to, or better than, Dreambooth on fidelity (since sometimes you want some constraint), while high-rank LoKr / small-block OFT will be closer to low-rank LoRA, which is better on diversity.

The differences are mostly introduced by the model capacity (most of the time this is basically the trainable parameter count).

You may want to check the LyCORIS paper, which discusses different metrics; it is very clear that when you want better "fidelity" you need larger model capacity, and when you want better "diversity" you need smaller capacity.

Models with the same capacity have only subtle differences; it is really hard to say which one is "better" in this case. In our experiments it is more like "some algorithm is more robust on the required property, so it is better" or "some algorithm is faster / more time-efficient", not "it actually produces significantly better results than the others".

@KohakuBlueleaf (Collaborator) commented Nov 13, 2023

And we (the LyCORIS team) don't have any metric evaluation on OFT yet.
I think it is planned (I will ask the other LyCORIS members later).
You can wait a few days or weeks for the results.

@cyber-meow commented Nov 15, 2023

I would like to draw attention to the fact that inference with OFT (or, more generally, any method based on multiplication by an additional matrix rather than addition) can be much more complicated once we consider the following two widely adopted practices in the community: using an additional network with a multiplier different from 1, and using multiple networks together.

Using an additional network with a smaller weight

Let us write $\gamma$ for the scaling factor used to apply OFT with a smaller/larger weight. I can think of at least three ways to achieve this (sketched in code after this list).

  1. The current implementation pretty much regards OFT as LoRA, and simply does $W' = (I + \gamma(R-I))W$.
  2. Provided that $R$ is parameterized as $R=(I+Q)(I-Q)^{-1}$, another natural idea is to define $R_\gamma = (I+\gamma Q)(I-\gamma Q)^{-1}$ and set $W' = R_\gamma W$. This may not be so different from the first method if we note that $(I+\gamma Q)(I-\gamma Q)^{-1}\approx I + 2\gamma Q$.
  3. However, what really sets OFT apart is the viewpoint of fine-tuning a rotation / orthogonal transformation instead of a weight difference. What aligns better with this viewpoint is therefore to apply a fractional matrix power, $W'=R^{\gamma}W$. In the case of OFT, the Cayley parameterization produces matrices from the special orthogonal group. To implement the matrix power efficiently, we may need to store the transformation in a different format where we directly have the $2\times 2$ rotation matrices (see e.g. https://math.stackexchange.com/questions/4353632/decomposition-of-son-into-so2-and-inversions). I am not totally sure here, but I believe this approach should be considered as well.
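For reference, the three options could be sketched as follows. This is hypothetical code: it assumes a dense $n\times n$ rotation $R$ and its Cayley parameter $Q$ rather than the block-diagonal layout, and option 3 takes the principal matrix power via an eigendecomposition (assuming no eigenvalue at $-1$).

import torch

def scale_oft_v1(R, W, gamma):
    # option 1: treat OFT like LoRA, W' = (I + gamma * (R - I)) W
    I = torch.eye(R.shape[-1], dtype=R.dtype, device=R.device)
    return (I + gamma * (R - I)) @ W

def scale_oft_v2(Q, W, gamma):
    # option 2: rescale Q inside the Cayley transform:
    # R_gamma = (I + gamma Q)(I - gamma Q)^-1, then W' = R_gamma W
    I = torch.eye(Q.shape[-1], dtype=Q.dtype, device=Q.device)
    R_gamma = (I + gamma * Q) @ torch.linalg.inv(I - gamma * Q)
    return R_gamma @ W

def scale_oft_v3(R, W, gamma):
    # option 3: fractional power of the rotation, W' = R^gamma W
    evals, evecs = torch.linalg.eig(R.to(torch.complex64))
    R_pow = (evecs @ torch.diag_embed(evals**gamma) @ torch.linalg.inv(evecs)).real
    return R_pow @ W.to(R_pow.dtype)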

Using multiple additional networks

For this I can think of two different approaches (sketched below):

  1. Either we apply it directly to the current weight, no matter what other networks have already been applied,
  2. Or we apply it to the base weight (with no additional networks) and then add the resulting difference to the current weight.
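In code, the difference between the two approaches might look like this (a sketch with hypothetical tensor names):

# approach 1: compose with whatever is already merged into the weight
W_current = R @ W_current

# approach 2: rotate the pristine base weight and carry over only the delta,
# so the OFT contribution combines additively with other networks
delta = R @ W_base - W_base
W_current = W_current + delta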

@v0xie (Contributor, Author) commented Nov 16, 2023

@cyber-meow Parameter-Efficient Orthogonal Finetuning via Butterfly Factorization: https://arxiv.org/abs/2311.06243

I found this new technical report while looking up more about OFT; maybe it has some insight into how the weights should be applied?

@v0xie (Contributor, Author) commented Nov 16, 2023

I think this is now ready to merge! Do you see any issues, @KohakuBlueleaf?

This PR now has:

Some quirks in this implementation:

  • LyCORIS has MultiheadAttention layers, which are not supported.
  • kohya-ss' COFT implementation has alpha (which is the constraint), while LyCORIS does not.
    • In either case, the constraint is ignored at inference if we assume the weights are pre-constrained / rescaled (which they should be); see the sketch after this list.
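For context, here is a sketch of the training-time COFT constraint as I understand kohya's oft.py (hypothetical names; shapes as elsewhere in this thread). The norm of the skew-symmetric parameter is clamped to the constraint during training, so by inference time the stored weights already satisfy the bound and the clamp can be skipped.

import torch

def apply_coft_constraint(oft_blocks: torch.Tensor, constraint: float,
                          eps: float = 1e-6) -> torch.Tensor:
    # clamp the norm of the skew-symmetric parameter Q to the constraint;
    # a no-op at inference if the weights are already pre-constrained
    Q = oft_blocks - oft_blocks.transpose(-1, -2)
    norm_Q = torch.norm(Q.flatten())
    scale = (torch.clamp(norm_Q, max=constraint) + eps) / (norm_Q + eps)
    return Q * scale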

@v0xie marked this pull request as ready for review November 16, 2023 02:47
@AUTOMATIC1111 merged commit 2207ef3 into AUTOMATIC1111:dev Nov 19, 2023
3 checks passed
@w-e-w mentioned this pull request Dec 4, 2023
@KohakuBlueleaf (Collaborator) commented:

@v0xie Sorry, I hadn't checked these things.
I need to point out that your implementation for LyCORIS' OFT is completely wrong.
You should not do the skew step, and you should not treat the tensor as "oft_blocks".
It is not "oft_blocks", because it IS the "R".

Don't do any extra operations on it.

@KohakuBlueleaf (Collaborator) commented:

Actually, you also do something wrong for the Kohya format:
you only rebuilt the Q and took it as R.

@w-e-w mentioned this pull request Dec 16, 2023