Fuse loras #4473
Conversation
Could you please elaborate? I didn't get it at all.
The documentation is not available anymore as the PR was closed or merged.
Kohya has similar functionality, and they have this:
Would it make sense to also enable this weighting in the (single) LoRA loading we are doing?
Also, as mentioned internally on Slack, it would be great to have a
@apolinario FWIW, this PR doesn't concern loading multiple LoRAs and fusing them. It simply fuses the LoRA params and the original model params in a way that doesn't affect the rest of our attention processing mechanism. We already have issues requesting multiple-LoRA support on GH.
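For intuition, fusing a LoRA into a linear layer just folds the low-rank update into the base weight matrix. Below is a minimal pure-PyTorch sketch with toy shapes and illustrative names (this is not the diffusers implementation):

```python
import torch

def fuse_lora_weight(w, w_up, w_down, scale=1.0):
    # W_fused = W + scale * (up @ down); after this, the forward pass
    # needs no extra LoRA computation at all.
    return w + scale * (w_up @ w_down)

torch.manual_seed(0)
w = torch.randn(8, 8)      # base weight of a toy 8x8 linear layer
w_up = torch.randn(8, 2)   # rank-2 LoRA factors
w_down = torch.randn(2, 8)
x = torch.randn(4, 8)

fused = fuse_lora_weight(w, w_up, w_down)
# The fused layer matches base output + separate LoRA branch output.
assert torch.allclose(x @ fused.T, x @ w.T + x @ (w_up @ w_down).T, atol=1e-5)
```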
@sayakpaul sorry, I was not suggesting that we should tackle loading multiple LoRAs in this PR; I just used Kohya's multiple-LoRA loading code example to showcase the weighting. I have edited my comment for clarity.
IIUC, our
Yup, you are right; it sounds redundant then if there's a general/global scale param for loading the LoRA. Edit: as discussed offline, while we do have a
For this PR we need to add a specific
diffusers/src/diffusers/models/lora.py, line 107 in aef11cb
@sayakpaul do you want to give this PR a try, or do you prefer me to finish it?
Giving it a try.
@patrickvonplaten up for a review here.
With this benchmarking script, I get the following on a V100:
Two questions:
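(For reference, the benefit of fusing shows up as fewer matmuls per forward pass: the unfused path pays two extra low-rank matmuls on every call. A toy CPU sketch of that comparison, with illustrative names only and no relation to the actual benchmarking script:)

```python
import time
import torch
from torch import nn

torch.manual_seed(0)
x = torch.randn(64, 1024)
base = nn.Linear(1024, 1024, bias=False)
up = torch.randn(1024, 8) * 0.02    # rank-8 LoRA factors
down = torch.randn(8, 1024) * 0.02

def unfused_forward(inp):
    # Base layer plus a separate LoRA branch: two extra matmuls per call.
    return base(inp) + inp @ down.T @ up.T

fused = nn.Linear(1024, 1024, bias=False)
with torch.no_grad():
    fused.weight.copy_(base.weight + up @ down)  # fold LoRA into the weight

def bench(fn, n=50):
    fn(x)  # warmup
    t0 = time.perf_counter()
    for _ in range(n):
        fn(x)
    return (time.perf_counter() - t0) / n

# The two paths are numerically equivalent; the fused one skips the LoRA matmuls.
assert torch.allclose(unfused_forward(x), fused(x), atol=1e-4)
print(f"unfused: {bench(unfused_forward)*1e6:.1f} us/call, fused: {bench(fused)*1e6:.1f} us/call")
```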
I am observing some inconsistencies in the tests and
Ah, I've found the problem, I think. The problem is that we currently don't fuse the text encoder weights when calling unload_lora_weights(); see this line: diffusers/src/diffusers/loaders.py, line 1727 in 53f2e74
The PR as it currently stands already works correctly for non-text-encoder LoRAs such as https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_offset_example-lora_1.0.safetensors. You can try it by doing:

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
pipe.load_lora_weights("stabilityai/stable-diffusion-xl-base-1.0", weight_name="sd_xl_offset_example-lora_1.0.safetensors")
pipe.unet.fuse_lora()

# This should have no effect as the LoRA is fused into the `unet` already,
# but instead it removes the LoRA's effect.
pipe.unload_lora_weights()

pipe.to(torch_dtype=torch.float16)
pipe.to("cuda")

torch.manual_seed(0)
prompt = "a mecha robot"
image = pipe(prompt, num_inference_steps=25).images[0]
image
```

I think we can actually also fuse the text encoder weights, no? I don't think that should be too difficult; what do you think, @sayakpaul?
Apart from this, the PR looks very nice! The tests that have been added are great.
@apolinario https://gist.github.com/sayakpaul/cd0395669002ae82634e57e5d26cc0e0. Regular, with fused LoRA, and with
(comparison images omitted)
@patrickvonplaten the PR is ready for a review. It would be great if you could do some testing as well. My only question at this point is how we allow for
Very cool! Everything works fine and the implementation is clean; nice job! Yes, I think you touch upon a good point here. I'd suggest doing the following in a follow-up PR:
Feel free to take this PR; otherwise, happy to take a look myself :-) Merging this one now!
Very nice @sayakpaul! However, the robot after the unfuse step looks different. From your own gist and example (and I reproduced it here):
Is this expected? Do we know why? I think this can affect hot-unloading use-cases, as after
```python
self.w_up = self.w_up.to(device=device, dtype=dtype)
self.w_down = self.w_down.to(device, dtype=dtype)
unfused_weight = fused_weight - torch.bmm(self.w_up[None, :], self.w_down[None, :])[0]
self.regular_linear_layer.weight.data = unfused_weight.to(device=device, dtype=dtype)
```
Actually, the reason unfuse gives different results might be that we don't do the computation in full fp32 precision here. In the `_fuse_lora` function we do the computation in full fp32 precision, but here we don't, I think. Can we make sure that:
`unfused_weight = fused_weight - torch.bmm(self.w_up[None, :], self.w_down[None, :])[0]`
is always computed in full fp32 precision, and only then potentially lower it to fp16?
I guess we could also easily check this by casting the whole model to fp32 before doing fuse and unfuse. cc @apolinario
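As a sanity check of this hypothesis, here is a small self-contained sketch with toy tensors in plain PyTorch (not the diffusers code), showing that the fuse/unfuse round trip is essentially exact in fp32, while storing the fused weight in fp16 already introduces a measurable error that no longer cancels on unfuse:

```python
import torch

torch.manual_seed(0)
w = torch.randn(64, 64)           # base weight (fp32 reference)
w_up = torch.randn(64, 4) * 0.1   # rank-4 LoRA factors
w_down = torch.randn(4, 64) * 0.1
delta = w_up @ w_down

# Round trip entirely in fp32: the original weight is recovered
# to machine precision.
unfused_fp32 = (w + delta) - delta
assert torch.allclose(unfused_fp32, w, atol=1e-5)

# Same round trip, but the fused weight is stored in fp16 (as in an
# fp16 pipeline): the fp16 rounding error survives the subtraction.
fused_fp16 = (w + delta).half()
unfused_fp16 = fused_fp16.float() - delta
err = (unfused_fp16 - w).abs().max().item()
assert err > 1e-4  # fp16 storage alone loses roughly 3 decimal digits here
print(f"max round-trip error through fp16 storage: {err:.2e}")
```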
This is a good hypothesis, but would it explain why this happens for some LoRAs but not others, and why some residual style is kept?

1. `davizca87/vulcan` → different image after unloading
2. `ostris/crayon_style_lora_sdxl` → same image after unloading
3. `davizca87/sun-flower` → different image after unloading (and also different from the image unloaded in 1)
4. `TheLastBen/Papercut_SDXL` → same image after unloading
Failed cases ❌

`davizca87/sun-flower` seems to keep some of the sunflower vibe in the background.

original generation | with LoRA fused | after unfusing (residual effects from the LoRA?)
(comparison images omitted)

`davizca87/vulcan` seems to keep some of the vulcan style in the outlines of the robot.

original generation | with LoRA fused | after unfusing (residual effects from the LoRA?)
(comparison images omitted)
Success cases ✅

`ostris/crayon_style_lora_sdxl` seems to produce a perceptually identical image after unfusing.

original generation | with LoRA fused | after unfusing (same as original generation)
(comparison images omitted)

`TheLastBen/Papercut_SDXL` and `nerijs/pixel-art-xl` also exhibit the same correct behavior.
I personally always worry about how numerical precision propagates through a network and affects the end results. I have seen enough cases of this to not sleep well at night. So I will start with what @patrickvonplaten suggested.
FWIW, though, there's actually a fast test that ensures the following trip doesn't have side effects:

`load_lora_weights()` -> `fuse_lora()` -> `unload_lora_weights()`

still gives you the outputs you would expect after doing `fuse_lora()`.

Let me know if anything is unclear.
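The behavior that this test pins down can be sketched with a toy LoRA-wrapped linear layer. All class and method names below are illustrative stand-ins, not the actual diffusers implementation:

```python
import torch
from torch import nn

class ToyLoRALinear(nn.Module):
    """Minimal stand-in for a LoRA-wrapped linear layer."""
    def __init__(self, linear, rank=4):
        super().__init__()
        self.linear = linear
        out_f, in_f = linear.weight.shape
        self.up = nn.Parameter(torch.randn(out_f, rank) * 0.1)
        self.down = nn.Parameter(torch.randn(rank, in_f) * 0.1)
        self.lora_active = True

    def forward(self, x):
        y = self.linear(x)
        if self.lora_active:
            y = y + x @ (self.up @ self.down).T  # separate LoRA branch
        return y

    def fuse_lora(self):
        # Fold the LoRA update into the base weight, then disable the branch.
        self.linear.weight.data += (self.up @ self.down).detach()
        self.lora_active = False

    def unload_lora(self):
        # Mimics unload_lora_weights(): drops the LoRA branch, but must
        # NOT touch weights that were already fused into the base layer.
        self.lora_active = False

torch.manual_seed(0)
layer = ToyLoRALinear(nn.Linear(8, 8))
x = torch.randn(2, 8)
with_lora = layer(x)

layer.fuse_lora()
layer.unload_lora()
# After fuse + unload, the output must still reflect the LoRA.
assert torch.allclose(layer(x), with_lora, atol=1e-5)
```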
Very nice! I guess an analogous test could be made in a future PR to address what I reported, namely asserting that the two generations with no LoRA match:

generate with no LoRA before unfusing → `load_lora_weights()` → `fuse_lora()` → `unfuse_lora()` → generate with no LoRA after unfusing

This is the workflow I reported above; the unfused unet seemingly still contains some residue of the LoRA.
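At the weight level, that round trip amounts to checking that subtracting the same delta that was added restores the base weights exactly. A quick sketch with toy tensors (illustrative only):

```python
import torch

torch.manual_seed(0)
w0 = torch.randn(16, 16)
delta = (torch.randn(16, 4) @ torch.randn(4, 16)) * 0.1  # rank-4 LoRA update

w = w0.clone()
w += delta   # fuse_lora(): fold the LoRA update into the base weight
w -= delta   # unfuse_lora(): remove it again
# No LoRA residue should remain; in fp32 this holds to machine precision.
assert torch.allclose(w, w0, atol=1e-5)
```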
Should be fixed in #4833. It was a tricky issue caused by the patched text encoder LoRA layers being fully removed when doing unload_lora, and therefore losing their ability to unfuse. Hope it was OK to help out a bit here, @sayakpaul.
I had started fuse-lora-pt2 and mentioned on Slack that I am looking into it. But okay.
@sayakpaul if you have some time to look into this before next week, that would be incredible! This way we can make a super nice SDXL LoRA release :-)
Prioritizing it.
@patrickvonplaten @sayakpaul how about adding a LoRA state dictionary to remember the LoRAs and their scale values? For example:

```python
lora_state = [
    {
        "lora_name": "LoRA1",
        "lora_scale": 0.7,
    },
    {
        "lora_name": "LoRA2",
        "lora_scale": 0.5,
    },
]
```

When a new LoRA is added to the pipeline, apply the scale delta to the pipeline if the LoRA already exists and update the `lora_state` object at the same time; add a new record if it was never added before. Now, whenever the user calls the
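The proposed bookkeeping could be sketched as below. Everything here is hypothetical and none of these names exist in diffusers; `apply_delta` stands for whatever routine would fuse `delta * (up @ down)` into the weights for a named LoRA:

```python
# Hypothetical scale bookkeeping for fused LoRAs; not a diffusers API.
lora_state = {}

def set_lora_scale(apply_delta, name, scale):
    # Only the difference between the requested and the currently fused
    # scale needs to be applied to the weights.
    current = lora_state.get(name, 0.0)
    apply_delta(name, scale - current)
    lora_state[name] = scale

def unfuse_all(apply_delta):
    # Subtract every recorded contribution so only the base weights remain.
    for name, scale in list(lora_state.items()):
        apply_delta(name, -scale)
        del lora_state[name]

# Toy check: track the net applied scale per LoRA name.
applied = {}
def apply_delta(name, delta):
    applied[name] = applied.get(name, 0.0) + delta

set_lora_scale(apply_delta, "LoRA1", 0.7)
set_lora_scale(apply_delta, "LoRA1", 0.5)  # re-set: applies only the -0.2 delta
set_lora_scale(apply_delta, "LoRA2", 0.5)
assert abs(applied["LoRA1"] - 0.5) < 1e-9 and abs(applied["LoRA2"] - 0.5) < 1e-9
unfuse_all(apply_delta)
assert all(abs(v) < 1e-9 for v in applied.values())
```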
* Fuse loras * initial implementation. * add slow test one. * styling * add: test for checking efficiency * print * position * place model offload correctly * style * style. * unfuse test. * final checks * remove warning test * remove warnings altogether * debugging * tighten up tests. * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * denugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debuging * debugging * debugging * debugging * suit up the generator initialization a bit. * remove print * update assertion. * debugging * remove print. * fix: assertions. * style * can generator be a problem? * generator * correct tests. * support text encoder lora fusion. * tighten up tests. --------- Co-authored-by: Sayak Paul <[email protected]>
What does this PR do?
Great idea from @williamberman to allow fusing LoRA weights into the original weights.
You can try it out with:
You can reverse the effect of `fuse_lora()` by calling `pipe.unfuse_lora()`. Refer to the test cases to get a better handle on the implications.

TODOs:
BTW @sayakpaul, we should probably refactor the attention LoRA processors to completely remove them and instead just work with the compatible linear layer class in the other attention processors (we'll have to do this anyway for the PEFT refactor). Done in #4765 by @patrickvonplaten.