
[LoRA] introduce LoraBaseMixin to promote reusability. #8670

Merged

sayakpaul merged 19 commits into main from feat-lora-base-class on Jul 3, 2024

Conversation

@sayakpaul (Member) commented Jun 24, 2024

What does this PR do?

TL;DR: Introduce a LoraBaseMixin class to reuse LoRA handling methods like set_adapters() across different pipelines as much as possible.

Long description

Currently, we have a LoraLoaderMixin class that implements several crucial methods for dealing with LoRAs. These methods include fuse_lora(), unfuse_lora(), set_adapters(), unload_lora(), disable_lora(), and so on.

However, this class is rigid in the sense that it only applies to the Stable Diffusion family of models, because it relies on the unet component of the underlying pipeline. But Transformers are increasingly being used as the denoiser; popular examples include SD3, Hunyuan DiT, and PixArt-{Sigma,Alpha}. So, we cannot use the above methods directly for these pipelines.

There are two ways to deal with this:

  • Implement all the above-mentioned methods with the unet component replaced with transformer.
  • Introduce a new base class that can determine the right type of denoiser being used in a pipeline and perform necessary adjustments for us.

The PR takes the latter approach.
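Conceptually, the base class resolves the denoiser component at runtime. A minimal sketch of the idea (simplified, not the exact code in this PR; the denoiser property name is illustrative):

class LoraBaseMixin:
    @property
    def denoiser(self):
        # Pipelines expose either `unet` or `transformer`; return whichever exists.
        if hasattr(self, "unet"):
            return self.unet
        if hasattr(self, "transformer"):
            return self.transformer
        raise ValueError("No supported denoiser component found on this pipeline.")

    def set_adapters(self, adapter_names, adapter_weights=None):
        # Shared methods can now operate on `self.denoiser` instead of
        # hardcoding `self.unet`.
        ...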

Looking at the code changes alone, this PR substantially reduces the line count, gracefully promotes reusability, and reduces the friction of incorporating a new pipeline that uses a non-UNet denoiser backbone.

Currently, we re-implement some methods in both UNet2DLoaderMixin and SD3TransformerLoaderMixin with the "# Copied from ..." mechanism. In a future PR, we can consider introducing a ModelLoaderMixin so that we can share methods like set_adapter() on the ModelMixin level. This should be relatively easy to do.
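To make that future idea concrete, a rough sketch (purely illustrative; no ModelLoaderMixin exists in this PR):

class ModelLoaderMixin:
    def set_adapter(self, adapter_name):
        # Adapter handling implemented once at the model level, shared by
        # UNets and Transformers alike instead of relying on "# Copied from ...".
        ...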

This PR will also make it easy to add LoRA support for other pipelines such as Hunyuan DiT, PixArt Sigma, etc. Broadly, all we have to do is the following (a hypothetical skeleton is sketched after this list):

  • load_lora_into_denoiser() where the denoiser can be a UNet or a Transformer
  • load_lora_weights()
  • Add LoRA support to the underlying denoiser classes through PeftLoaderMixin.
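The skeleton below illustrates the shape of such an addition (class name and signatures are illustrative, not the final API):

class HunyuanDiTLoraLoaderMixin(LoraBaseMixin):
    @classmethod
    def load_lora_into_denoiser(cls, state_dict, denoiser, adapter_name=None, _pipeline=None):
        # Route the LoRA state dict into the pipeline's denoiser, which may
        # be a UNet or a Transformer.
        ...

    def load_lora_weights(self, pretrained_model_name_or_path_or_dict, adapter_name=None, **kwargs):
        # Resolve the checkpoint into a state dict, then delegate to
        # load_lora_into_denoiser() (plus text encoder helpers, if any).
        ...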

I am going to request a review from @BenjaminBossan first. Once things look good, I will request reviews from other maintainers.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Comment on lines 327 to 334
def fuse_lora(
    self,
    fuse_unet: bool = True,
    fuse_text_encoder: bool = True,
    lora_scale: float = 1.0,
    safe_fusing: bool = False,
    adapter_names: Optional[List[str]] = None,
):
Member Author

So that the earlier public methods still work as expected. Also, showcases the power of reusability nicely.

@@ -522,115 +547,18 @@ def load_lora_into_text_encoder(
_pipeline.enable_sequential_cpu_offload()
# Unsafe code />

@classmethod
def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None):
@sayakpaul (Member Author), Jun 24, 2024

This method is only needed for the AmusedPipeline. It's safe to say that LoRA + Amused isn't used that much, and the method needlessly complicates this class.

To address this, I introduced an AmusedLoaderMixin in this PR and made the corresponding changes in the train_amused.py file. This is the only instance where LoRA + Amused is ever used in our codebase.

@classmethod
def save_lora_weights(
    cls,
    save_directory: Union[str, os.PathLike],
    unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
    text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
    transformer_lora_layers: Dict[str, torch.nn.Module] = None,
Member Author

No need for this argument here.

@sayakpaul (Member Author)

@Beinsezii could you test with this PR whether set_adapters() on SD3 LoRAs works as expected? This is the refactor I mentioned last week.

Thank you in advance.

@BenjaminBossan (Member)

Thanks for extending the LoRA functionality to make it work better with transformers. Also thanks for describing your approach.

I have not yet done a full review, as I wanted to ask one thing for my better understanding:

There are two ways to deal with this:

  • Implement all the above-mentioned methods with the unet component replaced with transformer.
  • Introduce a new base class that can determine the right type of denoiser being used in a pipeline and perform necessary adjustments for us.

The PR takes the latter approach.

Would it have been possible to roll the 2nd point into the 1st? I.e. introduce the new functionality inside of LoraLoaderMixin instead of creating a new class? SD3LoraLoaderMixin would then inherit from LoraLoaderMixin, same as StableDiffusionXLLoraLoaderMixin for instance, which from my naive point of view would be more intuitive.

If it is important to keep LoraUtilsMixin, I would at the very least consider renaming it. Right now, according to my understanding of the word, it is not a "mixin" class but instead an (abstract) base class and something like class SD3LoraLoaderMixin(LoraUtilsMixin) would be the concrete implementation.

@sayakpaul (Member Author)

Would it have been possible to roll the 2nd point into the 1st? I.e. introduce the new functionality inside of LoraLoaderMixin instead of creating a new class? SD3LoraLoaderMixin would then inherit from LoraLoaderMixin, same as StableDiffusionXLLoraLoaderMixin for instance, which from my naive point of view would be more intuitive.

So, for historical reasons, LoraLoaderMixin is for dealing with Stable Diffusion. StableDiffusionXLLoraLoaderMixin is for SDXL and so on. I could

  • introduce StableDiffusionLoaderMixin
  • move the methods of LoraLoaderMixin to StableDiffusionLoaderMixin
  • make LoraLoaderMixin a subclass of StableDiffusionLoaderMixin.

This way, we will be able to deprecate the LoraLoaderMixin class properly if needed. I hope this clarifies things?
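For illustration, the restructuring could look like this (a sketch under the naming above, not final code):

class StableDiffusionLoaderMixin:
    # All of the current LoraLoaderMixin methods would move here.
    ...

class LoraLoaderMixin(StableDiffusionLoaderMixin):
    # Kept as a thin subclass so existing user code keeps working; a
    # deprecation message can be added in a follow-up PR.
    pass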

If it is important to keep LoraUtilsMixin, I would at the very least consider renaming it. Right now, according to my understanding of the word, it is not a "mixin" class but instead an (abstract) base class and something like class SD3LoraLoaderMixin(LoraUtilsMixin) would be the concrete implementation.

I am okay with renaming the LoraUtilsMixin class. Do you have a suggestion? I am a little confused by the latter part. Did you mean SD3LoraLoaderMixin(LoraLoaderMixin)? SD3LoraLoaderMixin(LoraUtilsMixin) is currently the case.

@BenjaminBossan (Member)

Thanks a lot for explaining.

for historical reasons, LoraLoaderMixin is for dealing with Stable Diffusion. StableDiffusionXLLoraLoaderMixin is for SDXL and so on

Indeed, this is the part that confused me.

I could

  • introduce StableDiffusionLoaderMixin
  • move the methods of LoraLoaderMixin to StableDiffusionLoaderMixin
  • make LoraLoaderMixin a subclass of StableDiffusionLoaderMixin.

This way, we will be able to deprecate the LoraLoaderMixin class properly if needed.

Regarding the last point, you mean a subclass without methods, so basically just an alias (+ potentially deprecation message)? Is this so that user code that relies on LoraLoaderMixin can still function? If this is part of the "public" API of diffusers, that would be all right with me, otherwise I would consider even removing LoraLoaderMixin completely.

I am okay with renaming the LoraUtilsMixin class. Do you have a suggestion?

This would probably be better named LoraMixinBase or something along these lines, to signify that it's a base class for LoRA mixin classes.

Did you mean SD3LoraLoaderMixin(LoraLoaderMixin)? SD3LoraLoaderMixin(LoraUtilsMixin) is currently the case.

No, I meant that SD3LoraLoaderMixin(LoraUtilsMixin) can be kept, just with a new name like SD3LoraLoaderMixin(LoraMixinBase). This would make it clear that SD3LoraLoaderMixin is a concrete implementation of LoraMixinBase for SD3, StableDiffusionLoaderMixin is the concrete implementation for SD, etc.

@sayakpaul (Member Author)

Regarding the last point, you mean a subclass without methods, so basically just an alias (+ potentially deprecation message)? Is this so that user code that relies on LoraLoaderMixin can still function? If this is part of the "public" API of diffusers, that would be all right with me, otherwise I would consider even removing LoraLoaderMixin completely.

No, we cannot remove LoraLoaderMixin completely, as it's quite heavily used.

Is this so that user code that relies on LoraLoaderMixin can still function?

To keep the scope of this PR relatively manageable, I would prefer to handle the deprecation part in a future PR.

I will reflect the rest of the feedback for this PR and ping you once done.

As always, thanks for being thorough! This helps a lot!

@Beinsezii (Contributor)

@Beinsezii could you test with this PR whether set_adapters() on SD3 LoRAs works as expected? This is the refactor I mentioned last week.

Thank you in advance.

Using both the minimal repro code in #8565 and my own app on this branch, I tried about a dozen different combinations of multiple LoRAs, adapter strengths, and [non-]fused states. It seems to work effectively the same as the XL/SD pipelines now for my cases.

All done on AMD gfx1100; mileage may vary on other compute devices.

@sayakpaul (Member Author)

@Beinsezii perfect, thanks for testing!

@sayakpaul sayakpaul changed the title [LoRA] introduce LoraUtilsMixin to promote reusability. [LoRA] introduce LoraBaseMixin to promote reusability. Jun 26, 2024
Comment on lines +33 to +38
USE_PEFT_BACKEND,
is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
Member Author

Changes to support the results of make fix-copies.

@@ -25,13 +25,17 @@
)

from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import SD3LoraLoaderMixin
Member Author

Changes to support the results of make fix-copies.

pipeline_class = StableDiffusion3Pipeline

def get_dummy_components(self):
Member Author

No need, as we already cover all of these cases more thoroughly in PeftLoraLoaderMixinTests.

@sayakpaul (Member Author)

@BenjaminBossan please take it away for reviewing.

@BenjaminBossan (Member) left a comment

Thanks for refactoring the LoRA mixin classes to be easier to extend to new pipelines in the future.

This one is honestly quite difficult for me to review, as there have been a lot of changes but mostly they're about moving things around to better organize the code. Therefore, I haven't checked line by line but rather tried to understand the overall direction of the change.

When it comes to correctness, the tests should cover this. There have been some changes to the tests at the same time, so it's not completely clear to me if the tests still cover the exact same cases (or strictly more) as previously. Maybe in the future it would make sense to split the refactor of the code from the refactor of the tests. That way, we can have more confidence that the refactor of the code does not break anything.

Overall, I think the direction of the change is a good one, so from my point of view, this looks good.

@classmethod
def _best_guess_weight_name(
    cls, pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False
def fuse_lora(
Member

Does this method only exist to map fuse_unet to fuse_denoiser? If yes, this could be added as a comment. Same for other similar methods.

Member Author

Done.

Collaborator

I don't think this method is needed; we can just map the argument from the method in LoraBaseMixin.
We should have this class for loading only and LoraBaseMixin for the other methods.

Member Author

@yiyixuxu I hear your point, but I feel relatively strongly about having this method act as an interface (as explained in the comments).

Otherwise, the fuse_lora() method would need fuse_unet, fuse_transformer arguments, and we may have to support further arguments for whatever diffusion backbone the literature introduces next (a Mamba, for example).

With the current implementation, we get a clean logical separation of concepts through the fuse_denoiser argument. The denoiser here can be any model architecture in practice, as long as it is compatible and supported.

Furthermore, when users call fuse_lora() on an SD3 pipeline, they would see fuse_unet in the docstring and in their IDEs, which could confuse them.

So, keeping all of this in mind, I would prefer the current implementation.
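For reference, the interface being defended here looks roughly like this (a simplified sketch; only the argument naming matters for this discussion):

from typing import List, Optional

def fuse_lora(
    self,
    fuse_denoiser: bool = True,  # applies to UNet- and Transformer-based pipelines alike
    fuse_text_encoder: bool = True,
    lora_scale: float = 1.0,
    safe_fusing: bool = False,
    adapter_names: Optional[List[str]] = None,
):
    ...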

@yiyixuxu (Collaborator), Jul 3, 2024

We can start to deprecate fuse_unet, fuse_transformer, etc. and only use fuse_denoiser from this point on, no?
When fuse_unet or fuse_transformer is passed in **kwargs, we can map it to fuse_denoiser.
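A minimal sketch of that **kwargs mapping, assuming diffusers' existing deprecate helper (the version string is illustrative):

from diffusers.utils import deprecate

def fuse_lora(self, fuse_denoiser: bool = True, **kwargs):
    # Map the old per-backbone flags onto the new unified argument.
    if "fuse_unet" in kwargs:
        deprecate("fuse_unet", "1.0.0", "Use `fuse_denoiser` instead.")
        fuse_denoiser = kwargs.pop("fuse_unet")
    if "fuse_transformer" in kwargs:
        deprecate("fuse_transformer", "1.0.0", "Use `fuse_denoiser` instead.")
        fuse_denoiser = kwargs.pop("fuse_transformer")
    ...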

Member Author

Yes that makes perfect sense to me. Will do that :)

Member Author

b404a32 should have taken care of it.

logger = logging.get_logger(__name__)


class SD3TransformerLoadersMixin:
Member

Is this basically on the same level as LoraLoaderMixin but it doesn't make sense to inherit from LoraBaseMixin?

Member Author

No, this is analogous to https://github.com/huggingface/diffusers/blob/main/src/diffusers/loaders/unet.py, which implements the LoRA functionality at the model level.

LoraBaseMixin is for pipelines.

Hence I stated in the PR description:

Currently, we re-implement some methods in both UNet2DLoaderMixin and SD3TransformerLoaderMixin with the "# Copied from ..." mechanism. In a future PR, we can consider introducing a ModelLoaderMixin so that we can share methods like set_adapter() on the ModelMixin level. This should be relatively easy to do.

@sayakpaul (Member Author)

When it comes to correctness, the tests should cover this. There have been some changes to the tests at the same time, so it's not completely clear to me if the tests still cover the exact same cases (or strictly more) as previously. Maybe in the future it would make sense to split the refactor of the code from the refactor of the tests. That way, we can have more confidence that the refactor of the code does not break anything.

Fair concern, but the refactoring here genuinely affects the tests too, so I couldn't find a better way out :( For example, we moved set_adapters() to the base class. The SD and SDXL tests for that method are now covered through PeftLoraLoaderMixinTests. But testing that method for SD3 meant one of two things:

  • Duplicate the test from PeftLoraLoaderMixinTests, making the necessary adjustments for SD3.
  • Adjust PeftLoraLoaderMixinTests and make the SD3 LoRA test suite a subclass of it.

The latter made more sense to me. Hopefully that clarifies why the test-related changes had to be reflected here; a sketch of the resulting layout follows.
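A sketch of the resulting test layout (SD3LoRATests is a hypothetical name; details are simplified):

import unittest

from diffusers import StableDiffusion3Pipeline

class PeftLoraLoaderMixinTests:
    pipeline_class = None  # set by each concrete subclass

    def test_set_adapters(self):
        # Shared assertions that run for every pipeline-specific subclass.
        ...

class SD3LoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
    pipeline_class = StableDiffusion3Pipeline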

@sayakpaul sayakpaul requested review from DN6 and yiyixuxu June 26, 2024 16:52
@sayakpaul (Member Author)

@yiyixuxu @DN6 this is ready for a review now.


text_encoder_module.lora_magnitude_vector[
    adapter_name
] = text_encoder_module.lora_magnitude_vector[adapter_name].to(device)


class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
@yiyixuxu (Collaborator), Jul 2, 2024

No need to inherit from LoraLoaderMixin, no? It defines its own load_lora_weights and save_lora_weights.

We are not crazy about inheritance, but with this we have

pipeline -> StableDiffusionXLLoraLoaderMixin -> LoraLoaderMixin -> LoraBaseMixin

and I think that is a bit too much.

Can we try to make it work with something like XXPipeline(StableDiffusionXLLoraLoaderMixin, LoraBaseMixin)?

Member Author

It shares the load_lora_into_unet() and load_lora_into_text_encoder() methods from LoraLoaderMixin, which include the logic for loading Kohya and other popular non-diffusers LoRA checkpoints.

If we proceed with the option you suggested, we would need to have copies of those methods in StableDiffusionXLLoraLoaderMixin.

Collaborator

Can we move these shared methods into LoraBaseMixin then?

XXPipeline -> XXLoraLoaderMixin -> LoraBaseMixin

@sayakpaul (Member Author), Jul 3, 2024

Well, load_lora_into_unet() isn't applicable to Transformer-based diffusion pipelines. Plus, StableDiffusionXLLoraLoaderMixin also shares the lora_state_dict() method from LoraLoaderMixin, which is again very specific to the SD family and not shared by other classes of pipelines such as SD3.

Collaborator

OK, let's copy the method then.

It shares the load_lora_into_unet() and load_lora_into_text_encoder() methods from LoraLoaderMixin, which include the logic for loading Kohya and other popular non-diffusers LoRA checkpoints.

I only found one load_lora_into_text_encoder method on LoraBaseMixin, no? So the only thing we need is to copy load_lora_into_unet to StableDiffusionXLLoraLoaderMixin.

Also, does it make sense to rename LoraLoaderMixin to StableDiffusionLoraLoaderMixin?

Member Author

On it.

Member Author

Done on ad532a0. LMK.

@sayakpaul sayakpaul merged commit a2071a1 into main Jul 3, 2024
17 of 18 checks passed
@sayakpaul sayakpaul deleted the feat-lora-base-class branch July 3, 2024 01:34
sayakpaul added a commit that referenced this pull request Jul 3, 2024
sayakpaul added a commit that referenced this pull request Jul 3, 2024

Revert "[LoRA] introduce `LoraBaseMixin` to promote reusability. (#8670)"

This reverts commit a2071a1.
@sayakpaul sayakpaul restored the feat-lora-base-class branch July 3, 2024 01:35
sayakpaul added a commit that referenced this pull request Dec 23, 2024
* introduce `LoraBaseMixin` to promote reusability.

* up

* add more tests

* up

* remove comments.

* fix fuse_nan test

* clarify the scope of fuse_lora and unfuse_lora

* remove space
sayakpaul added a commit that referenced this pull request Dec 23, 2024

Revert "[LoRA] introduce `LoraBaseMixin` to promote reusability. (#8670)"

This reverts commit a2071a1.