[LoRA] introduce `LoraBaseMixin` to promote reusability. #8670

Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
src/diffusers/loaders/lora.py (Outdated)
def fuse_lora(
    self,
    fuse_unet: bool = True,
    fuse_text_encoder: bool = True,
    lora_scale: float = 1.0,
    safe_fusing: bool = False,
    adapter_names: Optional[List[str]] = None,
):
So that the earlier public methods still work as expected. It also showcases the power of reusability nicely.
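A quick usage sketch of what that looks like from the user's side (the model and LoRA repo IDs below are placeholders, not taken from this PR):

```python
from diffusers import StableDiffusionXLPipeline

# Placeholder checkpoints; any SDXL base model plus a compatible LoRA would do.
pipe = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
pipe.load_lora_weights("some-user/some-sdxl-lora", adapter_name="style")

# The pre-existing pipeline-level arguments keep working as before.
pipe.fuse_lora(fuse_unet=True, fuse_text_encoder=True, lora_scale=0.8)
pipe.unfuse_lora()
```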
@@ -522,115 +547,18 @@ def load_lora_into_text_encoder(
                _pipeline.enable_sequential_cpu_offload()
        # Unsafe code />

    @classmethod
    def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None):
This method is only needed for the `AmusedPipeline`. It's very safe to say that LoRA + Amused isn't used that much, and it is convoluting this class unnecessarily. To address this, I introduced an `AmusedLoaderMixin` in this PR and made the corresponding changes in the `train_amused.py` file. This is the only instance where LoRA + Amused is used in our codebase.
@classmethod
def save_lora_weights(
    cls,
    save_directory: Union[str, os.PathLike],
    unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
    text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
    transformer_lora_layers: Dict[str, torch.nn.Module] = None,
No need for this argument to be here.
@Beinsezii could you test with this PR? Thank you in advance.
Thanks for extending the LoRA functionality to make it work better with transformers. Also thanks for describing your approach. I have not yet done a full review, as I wanted to ask one thing for my better understanding:
Would it have been possible to roll the 2nd point into the 1st, i.e. introduce the new functionality inside of […]? If it is important to keep […]
So, for historical reasons, […]

This way, we will be able to deprecate the […]

I am okay with renaming the […]
Thanks a lot for explaining.
Indeed, this is the part that confused me.
Regarding the last point, you mean a subclass without methods, so basically just an alias (+ potentially a deprecation message)? Is this so that user code that relies on […]?

This would probably be better named as […]

No, I meant that […]
No, we cannot remove […]

To keep the scope of this PR relatively manageable, I would prefer to handle the deprecation part in a future PR. I will address the rest of the feedback for this PR and ping you once done. As always, thanks for being thorough! This helps a lot!
Using both the minimal repro code in #8565 and my own app on this branch, I tried about a dozen different combinations of multiple LoRAs, adapter strengths, and [non]fuses. Seems to work effectively the same as the XL/SD pipes now for my cases. All done on AMD gfx1100; other compute devices may vary.
@Beinsezii perfect, thanks for testing!
Changed the title from "[LoRA] introduce LoraUtilsMixin to promote reusability." to "[LoRA] introduce LoraBaseMixin to promote reusability."
    USE_PEFT_BACKEND,
    is_torch_xla_available,
    logging,
    replace_example_docstring,
    scale_lora_layers,
    unscale_lora_layers,
Changes to support the results of `make fix-copies`.
@@ -25,13 +25,17 @@
)

from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import SD3LoraLoaderMixin
Changes to support the results of `make fix-copies`.
pipeline_class = StableDiffusion3Pipeline

def get_dummy_components(self):
No need, as we more thoroughly cover all kinds of tests in `PeftLoraLoaderMixinTests` already.
@BenjaminBossan please take it away for review.
Thanks for refactoring the LoRA mixin classes to be easier to extend to new pipelines in the future.
This one is honestly quite difficult for me to review, as there have been a lot of changes but mostly they're about moving things around to better organize the code. Therefore, I haven't checked line by line but rather tried to understand the overall direction of the change.
When it comes to correctness, the tests should cover this. There have been some changes to the tests at the same time, so it's not completely clear to me if the tests still cover the exact same cases (or strictly more) as previously. Maybe in the future it would make sense to split the refactor of the code from the refactor of the tests. That way, we can have more confidence that the refactor of the code does not break anything.
Overall, I think the direction of the change is a good one, so from my point of view, this looks good.
@classmethod
def _best_guess_weight_name(
    cls, pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False

def fuse_lora(
Does this method only exist to map `fuse_unet` to `fuse_denoiser`? If yes, this could be added as a comment. Same for other similar methods.
Done.
I don't think this method is needed; we can just map the argument from the method in `LoraBaseMixin`. We should have this class for loading only and `LoraBaseMixin` for the other methods.
@yiyixuxu I hear your point, but I feel relatively strongly about having this method act as an interface (as explained in the comments). Otherwise, the `fuse_lora()` method would need to have `fuse_unet`, `fuse_transformer` arguments, and we may have to support other arguments based on the kind of diffusion backbone newly introduced in the literature (a Mamba, for example).

With the current implementation, we have a nice logical segregation of the concepts through the `fuse_denoiser` argument. The denoiser here can be any model architecture in practice, as long as there's compatibility and support.

Furthermore, when users call `fuse_lora()` on an SD3 pipeline, they would also see `fuse_unet` in the docstrings or in their IDEs. This can be confusing for them.

So, keeping all of this in mind, I would prefer the current implementation.
We can start to deprecate `fuse_unet`, `fuse_transformer`, etc. and only use `fuse_denoiser` from this point on, no? When `fuse_unet` or `fuse_transformer` is passed in `**kwargs`, we can map it to `fuse_denoiser`.
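A minimal sketch of the suggested mapping (the stand-in class name, warning text, and exact wiring are assumptions for illustration, not the code that ended up in the PR):

```python
import warnings


class SD3LoraLoaderMixin:  # illustrative stand-in for the pipeline-level mixin
    def fuse_lora(self, fuse_denoiser: bool = True, fuse_text_encoder: bool = True, **kwargs):
        # Map the legacy, backbone-specific flags onto the unified argument.
        for legacy in ("fuse_unet", "fuse_transformer"):
            if legacy in kwargs:
                warnings.warn(
                    f"`{legacy}` is deprecated. Please use `fuse_denoiser` instead.",
                    FutureWarning,
                )
                fuse_denoiser = kwargs.pop(legacy)
        # ... actual fusing of the denoiser / text-encoder LoRA layers happens here ...
```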
Yes that makes perfect sense to me. Will do that :)
b404a32 should have taken care of it.
logger = logging.get_logger(__name__)


class SD3TransformerLoadersMixin:
Is this basically on the same level as `LoraLoaderMixin`, but it doesn't make sense for it to inherit from `LoraBaseMixin`?
No, this is analogous to https://github.com/huggingface/diffusers/blob/main/src/diffusers/loaders/unet.py, which implements the LoRA functionalities at the model level. `LoraBaseMixin` is for pipelines. Hence I stated in the PR description:

> Currently, we re-implement some methods in both UNet2DLoaderMixin and SD3TransformerLoaderMixin with the "# Copied from ..." mechanism. In a future PR, we can consider introducing a ModelLoaderMixin so that we can share methods like set_adapter() on the ModelMixin level. This should be relatively easy to do.
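To illustrate the model-level vs. pipeline-level split (the method names below are illustrative placeholders; only the two class names come from the discussion):

```python
class SD3TransformerLoadersMixin:
    """Mixed into the transformer *model*: handles LoRA layers on this module only."""

    def load_lora_adapter_into_self(self, state_dict, adapter_name=None):
        # placeholder: inject LoRA layers into this single model, e.g. via PEFT
        ...


class LoraBaseMixin:
    """Mixed into *pipelines*: orchestrates LoRA across all components (denoiser, text encoders)."""

    def set_adapters(self, adapter_names, adapter_weights=None):
        # placeholder: forward the call to each LoRA-capable component of the pipeline
        ...
```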
Fair concern, but the refactoring here genuinely affects the tests too, so I couldn't find a better way out :( For example, we moved […] The latter made more sense to me. Hopefully that helps clarify why the test-related changes had to be reflected here.
text_encoder_module.lora_magnitude_vector[
    adapter_name
] = text_encoder_module.lora_magnitude_vector[adapter_name].to(device)


class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
No need to inherit from `LoraLoaderMixin`, no? It defines its own `load_lora_weights` and `save_lora_weights`.

We are not crazy about inheritance, but with this we have:

pipeline -> StableDiffusionXLLoraLoaderMixin -> LoraLoaderMixin -> LoraBaseMixin

I think it is a bit too much. Can we try to make it work with something like `XXPipeline(StableDiffusionXLLoraLoaderMixin, LoraBaseMixin)`?
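To make the two layouts concrete, here is a small, self-contained sketch with stand-in classes (none of this is the actual diffusers code; the `Deep`/`Flat` names are invented for the comparison):

```python
# Stand-ins so the two inheritance layouts can be compared side by side.
class LoraBaseMixin: ...
class LoraLoaderMixin(LoraBaseMixin): ...


# Layout currently in the PR (deeper chain):
class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin): ...
class DeepPipeline(StableDiffusionXLLoraLoaderMixin): ...


# Flatter layout suggested above: the SDXL mixin no longer inherits the SD loader,
# and the pipeline mixes in the base class directly.
class FlatSDXLLoraLoaderMixin: ...
class FlatPipeline(FlatSDXLLoraLoaderMixin, LoraBaseMixin): ...


print([c.__name__ for c in DeepPipeline.__mro__])
# ['DeepPipeline', 'StableDiffusionXLLoraLoaderMixin', 'LoraLoaderMixin', 'LoraBaseMixin', 'object']
print([c.__name__ for c in FlatPipeline.__mro__])
# ['FlatPipeline', 'FlatSDXLLoraLoaderMixin', 'LoraBaseMixin', 'object']
```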
It shares the `load_lora_into_unet()` and `load_lora_into_text_encoder()` methods from `LoraLoaderMixin`, which include the logic for loading Kohya and other popular non-diffusers LoRA checkpoints. If we proceed with the option you suggested, we would need copies of those methods in `StableDiffusionXLLoraLoaderMixin`.
Can we move these shared methods into `LoraBaseMixin` then?

XXPipeline -> XXLoraLoaderMixin -> LoraBaseMixin
Well, `load_lora_into_unet()` isn't applicable to Transformer-based diffusion pipelines. Plus, `StableDiffusionXLLoraLoaderMixin` also shares the `lora_state_dict()` method from `LoraLoaderMixin`, which is again very specific to the SD family and not shared by other classes of pipelines such as SD3.
OK, let's copy the method then.

> It shares the load_lora_into_unet() and load_lora_into_text_encoder() methods from LoraLoaderMixin which includes logic for loading Kohya and other popular non-diffusers LoRA checkpoints.

I only found one `load_lora_into_text_encoder` method on `LoraBaseMixin`, no? So the only thing we need is to copy `load_lora_into_unet` to `StableDiffusionXLLoraLoaderMixin`.

Also, does it make sense to rename `LoraLoaderMixin` to `StableDiffusionLoraLoaderMixin`?
On it.
Done on ad532a0. LMK.
…)" This reverts commit a2071a1.
* introduce to promote reusability. * up * add more tests * up * remove comments. * fix fuse_nan test * clarify the scope of fuse_lora and unfuse_lora * remove space
What does this PR do?
TL;DR: Introduce a `LoraBaseMixin` class to reuse LoRA handling methods like `set_adapters()` across different pipelines as much as possible.

Long description
Currently, we have a `LoraLoaderMixin` class that implements several crucial methods for dealing with LoRAs. These methods include `fuse_lora()`, `unfuse_lora()`, `set_adapters()`, `unload_lora()`, `disable_lora()`, and so on.

However, this class is very rigid in the sense that it only applies to the Stable Diffusion family of models. This is because it relies on the `unet` component of an underlying pipeline. But we have started to see that Transformers are more and more used as the denoiser. Some popular examples include SD3, Hunyuan DiT, and PixArt-{Sigma,Alpha}. So, we cannot make use of the above methods directly for these pipelines.

There are two ways to deal with this:

1. Have copies of the existing loader class with the `unet` component replaced with `transformer`.
2. Introduce a base class (`LoraBaseMixin`) that implements the shared LoRA handling methods and can be reused across pipelines.

The PR takes the latter approach.
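A minimal sketch of the base-class idea (the `_lora_loadable_modules` attribute and the dispatch logic here are assumptions for illustration, not the actual implementation in this PR):

```python
class LoraBaseMixin:
    # Each pipeline-specific subclass would declare its own LoRA-capable components.
    _lora_loadable_modules = ["unet", "transformer", "text_encoder"]

    def set_adapters(self, adapter_names, adapter_weights=None):
        if isinstance(adapter_names, str):
            adapter_names = [adapter_names]
        for module_name in self._lora_loadable_modules:
            component = getattr(self, module_name, None)
            if component is None:
                continue  # e.g. SD1.5 has no `transformer`, SD3 has no `unet`
            # placeholder for the PEFT-backed call that activates the adapters
            if hasattr(component, "set_adapters"):
                component.set_adapters(adapter_names, adapter_weights)
```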
If we just factor in the code changes, we'd notice that this PR substantially reduces the LoCs, gracefully promotes reusability, and reduces the friction to incorporate a new pipeline that uses a non-UNet denoiser backbone.
Currently, we re-implement some methods in both `UNet2DLoaderMixin` and `SD3TransformerLoaderMixin` with the "# Copied from ..." mechanism. In a future PR, we can consider introducing a `ModelLoaderMixin` so that we can share methods like `set_adapter()` on the `ModelMixin` level. This should be relatively easy to do.

This PR will also make it easy to add LoRA support for other pipelines such as Hunyuan DiT, PixArt Sigma, etc. Broadly, all we have to do is implement the following methods (see the sketch below):
* `load_lora_into_denoiser()`, where the denoiser can be a UNet or a Transformer
* `load_lora_weights()`
* […] `PeftLoaderMixin` […]
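A rough, hypothetical sketch of those steps for a new transformer-based pipeline (all names here other than `LoraBaseMixin` are invented for illustration, and the helpers are assumed, not taken from the PR):

```python
class LoraBaseMixin:  # stand-in for the base class introduced in this PR
    ...


class NewDiTLoraLoaderMixin(LoraBaseMixin):
    def load_lora_weights(self, pretrained_model_name_or_path_or_dict, adapter_name=None, **kwargs):
        # 1) resolve the checkpoint into a flat LoRA state dict (helper assumed)
        state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
        # 2) inject the LoRA layers into the denoiser; here the denoiser is a Transformer,
        #    for SD-style pipelines it would be a UNet.
        self.load_lora_into_denoiser(state_dict, denoiser=self.transformer, adapter_name=adapter_name)

    @classmethod
    def load_lora_into_denoiser(cls, state_dict, denoiser, adapter_name=None):
        # placeholder: attach the LoRA layers (e.g. via PEFT) to `denoiser`
        ...
```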
I am going to request a review from @BenjaminBossan first. Once things look good, I will request reviews from other maintainers.