
Fuse loras #4473
Merged (60 commits into main), Aug 29, 2023

Conversation

@patrickvonplaten (Contributor) commented Aug 4, 2023

What does this PR do?

Great idea from @williamberman to allow fusing LoRA weights into the original weights.

You can try it out with:

```python
from diffusers import DiffusionPipeline
import torch

pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
pipe.load_lora_weights("stabilityai/stable-diffusion-xl-base-1.0", weight_name="sd_xl_offset_example-lora_1.0.safetensors")
pipe.fuse_lora()

pipe.to(torch_dtype=torch.float16)
pipe.to("cuda")

torch.manual_seed(0)

prompt = "beautiful scenery nature glass bottle landscape, , purple galaxy bottle"
negative_prompt = "text, watermark"

image = pipe(prompt, negative_prompt=negative_prompt, num_inference_steps=25).images[0]
```

You can reverse the effect of fuse_lora() by calling pipe.unfuse_lora(). Refer to the test cases to get a better handle on the implications.
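For intuition, fusing folds the LoRA update directly into the base weight so that inference needs no extra matmuls, and unfusing subtracts it again. A minimal sketch of the idea on a single weight matrix (illustrative only, not the diffusers implementation; `w_up`, `w_down` and `scale` are assumed names):

```python
import torch

def fuse(weight: torch.Tensor, w_up: torch.Tensor, w_down: torch.Tensor, scale: float = 1.0):
    # W_fused = W + scale * (up @ down), where up/down are the low-rank LoRA matrices
    return weight + scale * (w_up @ w_down)

def unfuse(fused_weight: torch.Tensor, w_up: torch.Tensor, w_down: torch.Tensor, scale: float = 1.0):
    # Exact inverse in infinite precision; in fp16 a small numerical residue can remain
    return fused_weight - scale * (w_up @ w_down)

W = torch.randn(64, 64)
A, B = torch.randn(4, 64), torch.randn(64, 4)  # rank-4 LoRA: down (A) and up (B)
W_fused = fuse(W, B, A)
assert torch.allclose(unfuse(W_fused, B, A), W, atol=1e-5)
```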

TODOS:

  • Conv layers
  • Attention linear layers
  • Text encoder LoRA fusion

BTW @sayakpaul we should probably refactor the attention LoRA processors to completely remove them and instead just work with the compatible linear layer class in the other attention processors (we'll have to do this anyway for the PEFT refactor). Edit: done in #4765 by @patrickvonplaten.

@sayakpaul (Member)

BTW @sayakpaul we should probably refactor the attention lora processors to completely remove them and instead just work with the compatible linear layer class in the other attention processors (we'll have to do this anyways for the peft refactor)

Could you please elaborate? I didn't get it at all.


@apolinario (Collaborator) commented Aug 4, 2023

Kohya has similar functionality, with a ratio param: https://github.com/kohya-ss/sd-scripts/blob/main/networks/merge_lora.py#L33

```bash
python networks\merge_lora.py --sd_model ..\model\model.ckpt --save_to ..\lora_train1\model-char1-merged.safetensors --models ..\lora_train1\last.safetensors ..\lora_train2\last.safetensors --ratios 0.8 0.5
```

Would it make sense to enable this kind of weighting in the (single-LoRA) loading we are doing too?

@apolinario (Collaborator)

Also, as mentioned internally on Slack, it would be great to have a pipe.unet.unfuse_lora() for use cases where one may want to keep a model warm on infra and swap out LoRAs.

@sayakpaul (Member)

@apolinario FWIW, this PR doesn't concern loading multiple LoRAs and fusing them. It simply fuses the LoRA params and the original model params in a way that doesn't affect the rest of our attention processing mechanism.

We already have issue requests for multiple LoRAs on GH.

@apolinario (Collaborator) commented Aug 5, 2023

@sayakpaul sorry, I was not suggesting that we should tackle loading multiple LoRAs in this PR. I just used Kohya's multiple-LoRA loading example to showcase the ratios param; what I wanted to ask is whether we should do something analogous in this PR.

I have edited my comment for clarity

@sayakpaul (Member)

IIUC, our scale parameter can be used to control how much of the LoRA params gets merged into the corresponding original params. What am I missing?

@apolinario (Collaborator) commented Aug 5, 2023

Yup, you are right; it sounds redundant then if there's a general/global scale param for loading the LoRA.

Edit: as discussed offline, while we do have `cross_attention_kwargs={"scale": 0.5}` as stated here, that happens during inference, so in this PR we would need some sort of scale param for the model merging.
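For illustration, continuing from the example in the PR description above, these are the two places a scale could apply. The inference-time scale via `cross_attention_kwargs` exists today; a merge-time `lora_scale` argument on `fuse_lora()` is hypothetical here, shown only to make the distinction concrete:

```python
# Inference-time scaling (existing behavior): the LoRA contribution is weighted per call.
image = pipe(prompt, cross_attention_kwargs={"scale": 0.5}).images[0]

# Merge-time scaling (hypothetical signature): the scale would be baked into the fused
# weights, e.g. W_fused = W + 0.5 * (up @ down), and could no longer vary per call.
pipe.fuse_lora(lora_scale=0.5)  # illustrative only, not an argument this PR adds
image = pipe(prompt).images[0]
```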

@patrickvonplaten (Contributor, Author)

BTW @sayakpaul we should probably refactor the attention lora processors to completely remove them and instead just work with the compatible linear layer class in the other attention processors (we'll have to do this anyways for the peft refactor)

Could you please elaborate? I didn't get it at all.

For this PR we need to add a specific `fuse_lora` function to `LoRAXFormersAttnProcessor(nn.Module)` and `LoRAAttnProcessor2_0(nn.Module)` to get it working. That's OK for now. However, in the midterm we should try to fully delete these two classes and instead adapt `self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)` to use `LoRACompatibleLinear(nn.Linear)` (this will require some changes to LoRA loading and training, but I believe we can make it work in a fully backwards-compatible way; also cc @williamberman).
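Roughly, the idea is that a LoRA-compatible linear layer wraps a plain `nn.Linear`, optionally carries a LoRA adapter whose output is added on the fly, and can fold that adapter into its own `.weight` on demand. A minimal sketch of the pattern (assumed names, not the exact diffusers classes):

```python
import torch.nn as nn

class LoRALinearLayer(nn.Module):
    def __init__(self, in_features: int, out_features: int, rank: int = 4):
        super().__init__()
        self.down = nn.Linear(in_features, rank, bias=False)
        self.up = nn.Linear(rank, out_features, bias=False)
        nn.init.zeros_(self.up.weight)  # start as a no-op adapter

    def forward(self, x):
        return self.up(self.down(x))

class LoRACompatibleLinear(nn.Linear):
    """nn.Linear that can host a LoRA adapter and fuse it into its weight."""

    def __init__(self, *args, lora_layer: LoRALinearLayer = None, **kwargs):
        super().__init__(*args, **kwargs)
        self.lora_layer = lora_layer

    def fuse_lora(self):
        if self.lora_layer is None:
            return
        # Fold the low-rank delta into the dense weight: W += up @ down
        delta = self.lora_layer.up.weight @ self.lora_layer.down.weight
        self.weight.data += delta.to(self.weight.dtype)
        self.lora_layer = None  # the adapter is now baked into the weight

    def forward(self, x):
        out = super().forward(x)
        if self.lora_layer is not None:
            out = out + self.lora_layer(x)
        return out
```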

@apolinario mentioned this pull request Aug 17, 2023
@patrickvonplaten (Contributor, Author)

@sayakpaul do you want to give this PR a try or do you prefer me to finish it?

@sayakpaul (Member)

Giving it a try.

@sayakpaul (Member)

@patrickvonplaten up for a review here.

@sayakpaul (Member)

With this benchmarking script, I get the following on a V100:

{'fuse': False, 'total_time (ms)': '95874.1', 'memory (mb)': 13572}

{'fuse': True, 'total_time (ms)': '83744.8', 'memory (mb)': 13543}
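The referenced script itself is not reproduced in this thread; a comparable benchmark could look roughly like the sketch below (assumed structure: timing a few SDXL generations with and without `fuse_lora()` and reporting peak CUDA memory):

```python
import time
import torch
from diffusers import DiffusionPipeline

def benchmark(fuse: bool, n_runs: int = 3) -> dict:
    pipe = DiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
    ).to("cuda")
    pipe.load_lora_weights(
        "stabilityai/stable-diffusion-xl-base-1.0",
        weight_name="sd_xl_offset_example-lora_1.0.safetensors",
    )
    if fuse:
        pipe.fuse_lora()

    torch.cuda.reset_peak_memory_stats()
    start = time.time()
    for _ in range(n_runs):
        pipe("a mecha robot", num_inference_steps=25)
    elapsed_ms = (time.time() - start) * 1000
    mem_mb = int(torch.cuda.max_memory_allocated() / 2**20)
    return {"fuse": fuse, "total_time (ms)": f"{elapsed_ms:.1f}", "memory (mb)": mem_mb}

print(benchmark(fuse=False))
print(benchmark(fuse=True))
```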

@apolinario (Collaborator)

Two questions:

  • Are we supporting a scale param for fusing the LoRA as previously discussed? Or would cross_attention_kwargs={"scale": 0.5} still work with fused LoRAs? 
  • What happens if I fuse two LoRAs and then call unfuse_lora? Does it go back to the original state?

@sayakpaul (Member)

I am observing some inconsistencies in the tests and unfuse_lora(). I will update the PR once it's ready for review.

@patrickvonplaten (Contributor, Author)

Ah, I think I've found the problem. We currently don't fuse the text encoder weights, so when calling unload_lora_weights(), this line:

self._remove_text_encoder_monkey_patch()

is hit and causes the inference to be different.

The PR as it currently stands already works correctly for non-text-encoder LoRAs such as: https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_offset_example-lora_1.0.safetensors

You can try it by doing:

```python
from diffusers import DiffusionPipeline
import torch

pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
pipe.load_lora_weights("stabilityai/stable-diffusion-xl-base-1.0", weight_name="sd_xl_offset_example-lora_1.0.safetensors")
pipe.unet.fuse_lora()
# This should have no effect as the LoRA is fused to the `unet` already, but instead it is removing the LoRA's effect
pipe.unload_lora_weights()

pipe.to(torch_dtype=torch.float16)
pipe.to("cuda")

torch.manual_seed(0)
prompt = "a mecha robot"
image = pipe(prompt, num_inference_steps=25).images[0]
image
```

I think we can actually also fuse the text encoder weights, no? I don't think that should be too difficult; what do you think @sayakpaul?

@patrickvonplaten (Contributor, Author)

Apart from this the PR looks very nice! The tests that have been added are great.

@sayakpaul (Member)

@apolinario https://gist.github.com/sayakpaul/cd0395669002ae82634e57e5d26cc0e0.

Regular:

image

With fused LoRA:

image

With unload_lora_weights() (no-op):

image

With unfuse_lora():

image

@sayakpaul (Member) commented Aug 29, 2023

@patrickvonplaten PR is ready for a review. Would be great if you could do some testing as well.

My only question at this point is how we allow for scale once fuse_lora() has been called. We need to think about that, and once we have a way, we should write thorough test cases.

@patrickvonplaten changed the title from "[WIP] Fuse loras" to "Fuse loras" on Aug 29, 2023
@patrickvonplaten (Contributor, Author)

Very cool! Everything works fine and the implementation is clean - nice job!

Yes, I think you touch upon a good point here regarding "lora_scale"; we should probably handle this in a new PR.
We currently also have this 🚧 regarding the "lora_scale": #4751

=> I'd suggest doing the following in a follow-up PR:

  • a) Make sure that "lora_scale" is applied to all LoRA layers. Instead of passing it via cross_attention_kwargs through the attention processors, it's probably better to work with a "setter" method that iterates through all layers and, if a layer is a LoRA layer, sets the "lora_scale" on that specific class (see the sketch after this list).
  • b) Also solve:

My only question at this point is how we allow for scale once fuse_lora() has been called. We need to think about that, and once we have a way, we should write thorough test cases.

=> I think we can do this by just multiplying the fused weight with that "lora_scale" and it should work.

  • c) We should also throw a warning if a user fuses more than one LoRA, IMO.
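As referenced in point a), a rough sketch of such a setter, plus the scaled fuse from point b). The attribute names follow the LoRA-compatible-linear idea discussed earlier in the thread and are assumptions, not an existing API:

```python
import torch.nn as nn

def set_lora_scale(model: nn.Module, lora_scale: float) -> None:
    # Point a): walk every submodule and stamp the scale onto LoRA-capable layers,
    # instead of threading it through cross_attention_kwargs on every call.
    for module in model.modules():
        if getattr(module, "lora_layer", None) is not None:
            module.lora_scale = lora_scale

def fuse_lora_scaled(linear: nn.Linear, lora_scale: float = 1.0) -> None:
    # Point b): bake the scale into the fused weight, W_fused = W + lora_scale * (up @ down).
    lora = linear.lora_layer  # assumed attribute, as in the sketch earlier in the thread
    delta = lora.up.weight @ lora.down.weight
    linear.weight.data += lora_scale * delta.to(linear.weight.dtype)
```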

Feel free to take this PR - otherwise happy to take a look myself :-)

Merging this one now!

@patrickvonplaten merged commit c583f3b into main Aug 29, 2023
@patrickvonplaten deleted the fuse_loras branch August 29, 2023 07:14
@apolinario (Collaborator) commented Aug 29, 2023

Very nice @sayakpaul! However, the robot after pipe.unfuse_lora() is different from the robot before fusing the LoRA:

From your own gist and example (and I reproduced it here):

| Before fusing the LoRA | After fusing and unfusing the LoRA |
| --- | --- |
| Image 1 | Image 2 |
| prompt: a mecha robot, seed: 0 | prompt: a mecha robot, seed: 0 |

Is this expected? Do we know why? I think this can affect hot-unloading use cases: after pipe.unfuse_lora(), my understanding is that users should expect the same image as if the LoRA had never been fused.

Comment on lines +120 to +123
```python
self.w_up = self.w_up.to(device=device, dtype=dtype)
self.w_down = self.w_down.to(device, dtype=dtype)
unfused_weight = fused_weight - torch.bmm(self.w_up[None, :], self.w_down[None, :])[0]
self.regular_linear_layer.weight.data = unfused_weight.to(device=device, dtype=dtype)
```
@patrickvonplaten (Contributor, Author)
Actually, the reason unfuse gives different results might be that we don't do the computation in full fp32 precision here. In the _fuse_lora function we do the computation in full fp32 precision, but here I think we don't. Can we try to make sure that

unfused_weight = fused_weight - torch.bmm(self.w_up[None, :], self.w_down[None, :])[0]

is always computed in full fp32 precision, and only then potentially lower it to fp16?

I guess we could also easily check this by casting the whole model to fp32 before doing fuse and unfuse. cc @apolinario
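A minimal sketch of that suggestion, upcasting the subtraction to fp32 and only casting back at the end (illustrative, not the actual patch applied in the PR):

```python
import torch

def unfuse_in_fp32(fused_weight: torch.Tensor, w_up: torch.Tensor, w_down: torch.Tensor) -> torch.Tensor:
    dtype, device = fused_weight.dtype, fused_weight.device
    # Do the subtraction in fp32 to avoid fp16 rounding residue ...
    unfused = fused_weight.float() - torch.bmm(
        w_up.float()[None, :], w_down.float()[None, :]
    )[0]
    # ... and only then cast back to the original (possibly fp16) dtype.
    return unfused.to(device=device, dtype=dtype)
```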

@apolinario (Collaborator) commented Aug 29, 2023

This is a good hypothesis, but would it explain why this happens with some LoRAs and not others, and why some residual style is kept?

  1. davizca87/vulcan → different image after unloading
  2. ostris/crayon_style_lora_sdxl → same image after unloading
  3. davizca87/sun-flower → different image after unloading (and also different from the image unloaded in 1)
  4. TheLastBen/Papercut_SDXL→ same image after unloading

@apolinario (Collaborator) commented Aug 29, 2023

Failed cases ❌

davizca87/sun-flower seems to keep some of the sunflower vibe in the background

| original generation | with lora fused | after unfusing (residual effects from the LoRA?) |
| --- | --- | --- |
| image1 | image2 | image3 |

davizca87/vulcan seems to keep some of the vulcan style in the outlines of the robot

| original generation | with lora fused | after unfusing (residual effects from the LoRA?) |
| --- | --- | --- |
| image4 | image5 | image6 |

Success cases ✅

ostris/crayon_style_lora_sdxl seems to produce a perceptually identical image after unfusing

| original generation | with lora fused | after unfusing (same as original generation) |
| --- | --- | --- |
| image4 | image | image |

TheLastBen/Papercut_SDXL and nerijs/pixel-art-xl also exhibit the same correct behavior

@sayakpaul (Member)

I personally always worry about how numerical precision propagates through a network and affects the end results. I have seen enough such cases to not sleep well at night. So, I will start with what @patrickvonplaten suggested.

FWIW, though, there's actually a fast test that ensures the following round trip doesn't have side effects:

load_lora_weights() -> fuse_lora() -> unload_lora_weights() gives you the outputs you would expect after doing fuse_lora().

Let me know if anything is unclear.

@apolinario (Collaborator) commented Aug 29, 2023

Very nice! I guess an analogous test could be made to address what I reported for a future PR

Namely, asserting that the two no-LoRA generations match:

generate with no LoRA (before) → load_lora_weights() → fuse_lora() → unfuse_lora() → generate with no LoRA (after unfusing)

This is the workflow I reported above; the unfused UNet seemingly still contains some residue of the LoRA.
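A sketch of what such a test could look like, in the spirit of the existing fast tests (the helper name and tolerance are assumptions):

```python
import numpy as np
import torch

def test_unfuse_restores_original_outputs(pipe, lora_repo_id, prompt="a mecha robot"):
    # Generation before any LoRA is loaded.
    before = pipe(prompt, num_inference_steps=2, output_type="np",
                  generator=torch.manual_seed(0)).images

    # Load, fuse, and unfuse the LoRA.
    pipe.load_lora_weights(lora_repo_id)
    pipe.fuse_lora()
    pipe.unfuse_lora()

    # Generation after the fuse/unfuse round trip should match the original.
    after = pipe(prompt, num_inference_steps=2, output_type="np",
                 generator=torch.manual_seed(0)).images
    assert np.allclose(before, after, atol=1e-3)
```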

@patrickvonplaten (Contributor, Author)

Should be fixed in #4833. It was a tricky issue caused by the patched text encoder LoRA layers being fully removed when doing unload_lora, and therefore losing their ability to unfuse. Hope it was OK to help out a bit here @sayakpaul.

@sayakpaul (Member)

I had started fuse-lora-pt2 and mentioned on Slack that I am looking into it. But okay.

@patrickvonplaten (Contributor, Author)


@sayakpaul if you have some time to look into this before next week that would be incredible! This way we can make a super nice SDXL LoRA release :-)

@sayakpaul (Member)

@sayakpaul if you have some time to look into this before next week that would be incredible! This way we can make a super nice SDXL LoRA release :-)

Prioritizing it.

@xhinker (Contributor) commented Sep 23, 2023

I see. My understanding was that we would get that for "free": once a LoRA is fused into the UNet, loading another LoRA would be the same as loading a LoRA into a base model (given it's loading into a fused UNet). What am I missing?

@apolinario your understanding is correct. However, I think we need to think about that design a bit to not break the API consistency and consider the repercussions of that. I am not suggesting we shouldn't allow it, of course, we should. But the API and/or docs for it need to be very tight to not introduce any confusion for the users.

I agree! We should be careful about introducing new APIs here. A suggestion: there is no harm in allowing the user to fuse multiple LoRAs, but as soon as two LoRAs are fused, the unfuse-LoRA function no longer works. => What do you think about keeping an internal counter of how many LoRAs are fused and, when this number is > 1, throwing an error on unfuse stating that multiple LoRAs have been fused and that the user needs to reload the whole model instead?

I'd say only once we have a full PEFT integration should we have full support for fusing and unfusing multiple LoRAs by name.

Thoughts?

@patrickvonplaten @sayakpaul how about adding a LoRA state dictionary to remember the LoRAs and their scale values? For example:

```python
lora_state = [
    {
        "lora_name": "LoRA1",
        "lora_scale": 0.7,
    },
    {
        "lora_name": "LoRA2",
        "lora_scale": 0.5,
    },
]
```

When a new LoRA is added to the pipeline, apply the scale delta to the pipeline if the LoRA already exists, and update the lora_state object at the same time; add a new record if it has never been added before.

Now, whenever the user calls the unfuse() function, the pipeline removes the LoRA weights multiplied by the lora_state's lora_scale. In theory it should work; thoughts?
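A hedged sketch of that bookkeeping idea: the pipeline keeps a record of every fused LoRA together with its scale, and unfusing subtracts each recorded, scaled delta (names and structure are illustrative, not an existing diffusers API):

```python
import torch

class FusedLoRARegistry:
    """Tracks which LoRAs are fused into a weight, and at what scale."""

    def __init__(self):
        # Each entry: {"lora_name": str, "lora_scale": float, "w_up": Tensor, "w_down": Tensor}
        self.entries = []

    def fuse(self, weight: torch.Tensor, name: str, scale: float,
             w_up: torch.Tensor, w_down: torch.Tensor) -> None:
        # Bake the scaled delta into the weight and remember how it was applied.
        weight.data += scale * (w_up @ w_down).to(weight.dtype)
        self.entries.append({"lora_name": name, "lora_scale": scale,
                             "w_up": w_up, "w_down": w_down})

    def unfuse_all(self, weight: torch.Tensor) -> None:
        # Remove every recorded LoRA delta, scaled by the scale it was fused with.
        for entry in self.entries:
            delta = entry["lora_scale"] * (entry["w_up"] @ entry["w_down"])
            weight.data -= delta.to(weight.dtype)
        self.entries.clear()
```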

yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
* Fuse loras

* initial implementation.

* add slow test one.

* styling

* add: test for checking efficiency

* print

* position

* place model offload correctly

* style

* style.

* unfuse test.

* final checks

* remove warning test

* remove warnings altogether

* debugging

* tighten up tests.

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* denugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debuging

* debugging

* debugging

* debugging

* suit up the generator initialization a bit.

* remove print

* update assertion.

* debugging

* remove print.

* fix: assertions.

* style

* can generator be a problem?

* generator

* correct tests.

* support text encoder lora fusion.

* tighten up tests.

---------

Co-authored-by: Sayak Paul <[email protected]>
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024
(same commit list as above)