-
Notifications
You must be signed in to change notification settings - Fork 26.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support inference with OFT networks #13692
Conversation
Here's a .safetensors OFT network I trained on part of the Dreambooth dataset: https://huggingface.co/v0xie/sdxl-oft_monster_toy/resolve/main/monster_toy_oft.safetensors SDXL OFT network trained with kohya-ss sd-scripts |
The slowdown during inference is fixed now; anecdotally, it's at least as fast as using any other network! Instead of replacing the forward method, now we're merging the weights into the weights of the original module (merge_weight function). I'm trying to figure out how to undo this replacement or if it's even necessary to replace the weights. |
The remaining issue of weights continuing to be loaded is fixed, and this PR is ready to merge. Also fixed is another bug where the multiplier was being applied again in finalize_updown which unpredictably increased the effect of U-Net / Text Encoder multipliers. Kohya's implementation applies the multiplier in the get_weight function, so I've maintained the same behavior for this implementation. |
@v0xie The method kohya implemented is too unefficient. |
… diag_oft has MultiheadAttn which kohya's doesn't?, attempt create new module based off network_lora.py, errors about tensor dim mismatch
Support LyCORIS diag-oft OFT implementation (minus MultiheadAttention layer), maintains support for kohya-ss OFT
@KohakuBlueleaf I just added support for diag-oft module. I don't know how to handle the Linear MultiheadAttention layers yet so those are skipped until I can better understand how they work. The implementation is a bit complicated since I'm maintaining support for kohya OFT module, so maybe it would be better to have two separate files? |
Kohya's OFT is mathmetically identical with my implementation (but unefficient) |
Use same updown implementation for LyCORIS OFT as kohya-ss OFT
@v0xie I have do more investigation on it I will add constrained OFT in the future and I will use Kohya's naming if I use Q in here. Basically the idea is: if using_lycoris:
# direct use weight + I as R
else:
# use kohya's implementation to rebuild R first
# rebuild weight from R use my implementation And a warning here. |
@v0xie I changed my implementation so you can just use my logic mentioned above |
Thank you for this, I'll take a look at reworking the implementation. |
did you guys got better results (quality not speed) than proper dreambooth training? I don't believe it but would like to get your opinion |
It is hardly depends on how you defined "better" It is quite like "high rank lokr" vs "low rank lokr" vs "dreambooth" The differences is more introduced by the model capacity (most of time it is basically the trainable param count) You may want to check the paper of LyCORIS, which discuss on different metrics and it is very clear that when you want a better "fidelity" you need larger model capacity, and when you want better "diversity" you need smaller capacity. Models with same capacity have subtle differences, really hard to say which one is "better" in this case. |
And we(LyCORIS team) don't have any metric evaluation on OFT yet. |
I would like to draw attention on the fact that the inference with OFT (or more generally, any method based on multiplication of additional matrix rather than addition) could be much more complicated if we consider the following two widely adopted practices in the community: using an additional network with a multiplier different from 1 and using multiple networks together. Using additional network with smaller weightLet us write
Using multiple additional networksFor this I can think of two different approaches:
|
@cyber-meow Parameter-Efficient Orthogonal Finetuning via Butterfly Factorization https://arxiv.org/abs/2311.06243 Found this new technical report while looking up more about OFT; maybe it has some insight on how weights should be applied? |
I think this is now ready to merge! Any issues you see @KohakuBlueleaf ? This PR now has:
Some quirks in this implementation:
|
@v0xie Sorry I haven't checked this things Don't do any extra operation on it |
Actually You also do something wrong in Kohya. |
Description
This PR adds support for inference of OFT networks trained with kohya-ss sd-scripts. The implementation is based on kohya's implementation here: https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py
This is a draft PR because of these major issues:The network remains affects all generations after loading even when supposed to be unloadedIt noticeably slows down inferenceThe current implementation replaces sd_module's forward method with a custom forward method, which I believe is the cause of the network continuing to affect future generations and the speed of inference.Any suggestions on how to fix these issues is greatly appreciated!Related links:
OFT project page: https://oft.wyliu.com
OFT official repository: https://github.com/Zeju1997/oft
Screenshots/videos:
Checklist: