Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[LoRA] introduce LoraBaseMixin to promote reusability. #8670

Merged
merged 19 commits into from
Jul 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions docs/source/en/api/loaders/lora.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@ specific language governing permissions and limitations under the License.

# LoRA

LoRA is a fast and lightweight training method that inserts and trains a significantly smaller number of parameters instead of all the model parameters. This produces a smaller file (~100 MBs) and makes it easier to quickly train a model to learn a new concept. LoRA weights are typically loaded into the UNet, text encoder or both. There are two classes for loading LoRA weights:
LoRA is a fast and lightweight training method that inserts and trains a significantly smaller number of parameters instead of all the model parameters. This produces a smaller file (~100 MBs) and makes it easier to quickly train a model to learn a new concept. LoRA weights are typically loaded into the denoiser, text encoder or both. The denoiser usually corresponds to a UNet ([`UNet2DConditionModel`], for example) or a Transformer ([`SD3Transformer2DModel`], for example). There are several classes for loading LoRA weights:

- [`LoraLoaderMixin`] provides functions for loading and unloading, fusing and unfusing, enabling and disabling, and more functions for managing LoRA weights. This class can be used with any model.
- [`StableDiffusionXLLoraLoaderMixin`] is a [Stable Diffusion (SDXL)](../../api/pipelines/stable_diffusion/stable_diffusion_xl) version of the [`LoraLoaderMixin`] class for loading and saving LoRA weights. It can only be used with the SDXL model.
- [`SD3LoraLoaderMixin`] provides similar functions for [Stable Diffusion 3](https://huggingface.co/blog/sd3).
- [`AmusedLoraLoaderMixin`] is for the [`AmusedPipeline`].
- [`LoraBaseMixin`] provides a base class with several utility methods to fuse, unfuse, unload, LoRAs and more.

<Tip>

Expand All @@ -29,4 +32,16 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse

## StableDiffusionXLLoraLoaderMixin

[[autodoc]] loaders.lora.StableDiffusionXLLoraLoaderMixin
[[autodoc]] loaders.lora.StableDiffusionXLLoraLoaderMixin

## SD3LoraLoaderMixin

[[autodoc]] loaders.lora.SD3LoraLoaderMixin

## AmusedLoraLoaderMixin

[[autodoc]] loaders.lora.AmusedLoraLoaderMixin

## LoraBaseMixin

[[autodoc]] loaders.lora_base.LoraBaseMixin
10 changes: 5 additions & 5 deletions examples/amused/train_amused.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@

import diffusers.optimization
from diffusers import AmusedPipeline, AmusedScheduler, EMAModel, UVit2DModel, VQModel
from diffusers.loaders import LoraLoaderMixin
from diffusers.loaders import AmusedLoraLoaderMixin
from diffusers.utils import is_wandb_available


Expand Down Expand Up @@ -532,7 +532,7 @@ def save_model_hook(models, weights, output_dir):
weights.pop()

if transformer_lora_layers_to_save is not None or text_encoder_lora_layers_to_save is not None:
LoraLoaderMixin.save_lora_weights(
AmusedLoraLoaderMixin.save_lora_weights(
output_dir,
transformer_lora_layers=transformer_lora_layers_to_save,
text_encoder_lora_layers=text_encoder_lora_layers_to_save,
Expand Down Expand Up @@ -566,11 +566,11 @@ def load_model_hook(models, input_dir):
raise ValueError(f"unexpected save model: {model.__class__}")

if transformer is not None or text_encoder_ is not None:
lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
LoraLoaderMixin.load_lora_into_text_encoder(
lora_state_dict, network_alphas = AmusedLoraLoaderMixin.lora_state_dict(input_dir)
AmusedLoraLoaderMixin.load_lora_into_text_encoder(
lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_
)
LoraLoaderMixin.load_lora_into_transformer(
AmusedLoraLoaderMixin.load_lora_into_transformer(
lora_state_dict, network_alphas=network_alphas, transformer=transformer
)

Expand Down
17 changes: 15 additions & 2 deletions src/diffusers/loaders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,18 @@ def text_encoder_attn_modules(text_encoder):

if is_torch_available():
_import_structure["single_file_model"] = ["FromOriginalModelMixin"]
_import_structure["transformer_sd3"] = ["SD3TransformerLoadersMixin"]

_import_structure["unet"] = ["UNet2DConditionLoadersMixin"]
_import_structure["utils"] = ["AttnProcsLayers"]
if is_transformers_available():
_import_structure["single_file"] = ["FromSingleFileMixin"]
_import_structure["lora"] = ["LoraLoaderMixin", "StableDiffusionXLLoraLoaderMixin", "SD3LoraLoaderMixin"]
_import_structure["lora"] = [
"AmusedLoraLoaderMixin",
"LoraLoaderMixin",
"SD3LoraLoaderMixin",
"StableDiffusionXLLoraLoaderMixin",
]
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
_import_structure["ip_adapter"] = ["IPAdapterMixin"]

Expand All @@ -69,12 +76,18 @@ def text_encoder_attn_modules(text_encoder):
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
if is_torch_available():
from .single_file_model import FromOriginalModelMixin
from .transformer_sd3 import SD3TransformerLoadersMixin
from .unet import UNet2DConditionLoadersMixin
from .utils import AttnProcsLayers

if is_transformers_available():
from .ip_adapter import IPAdapterMixin
from .lora import LoraLoaderMixin, SD3LoraLoaderMixin, StableDiffusionXLLoraLoaderMixin
from .lora import (
AmusedLoraLoaderMixin,
LoraLoaderMixin,
SD3LoraLoaderMixin,
StableDiffusionXLLoraLoaderMixin,
)
from .single_file import FromSingleFileMixin
from .textual_inversion import TextualInversionLoaderMixin

Expand Down
Loading
Loading