
[LoRA] introduce LoraBaseMixin to promote reusability. #8774

Merged
merged 59 commits on Jul 25, 2024
Changes from 21 commits
Commits
59 commits
b66885b
introduce to promote reusability.
sayakpaul Jun 24, 2024
124b698
up
sayakpaul Jun 24, 2024
8828863
Merge branch 'main' into feat-lora-base-class
sayakpaul Jun 24, 2024
31bdbbf
Merge branch 'main' into feat-lora-base-class
sayakpaul Jun 24, 2024
bb03165
add more tests
sayakpaul Jun 24, 2024
d60445c
Merge branch 'main' into feat-lora-base-class
sayakpaul Jun 24, 2024
5ce5c8b
Merge branch 'main' into feat-lora-base-class
sayakpaul Jun 25, 2024
6448e92
up
sayakpaul Jun 26, 2024
865b94f
up
sayakpaul Jun 26, 2024
b68e385
Merge branch 'main' into feat-lora-base-class
sayakpaul Jun 26, 2024
4af4a8d
remove comments.
sayakpaul Jun 26, 2024
1c80fa7
fix fuse_nan test
sayakpaul Jun 26, 2024
1e5a5ed
Merge branch 'main' into feat-lora-base-class
sayakpaul Jun 26, 2024
6870915
clarify the scope of fuse_lora and unfuse_lora
sayakpaul Jun 26, 2024
45404e9
resolve conflicts.
sayakpaul Jun 27, 2024
6658771
remove space
sayakpaul Jun 27, 2024
0990909
Merge branch 'main' into feat-lora-base-class
sayakpaul Jul 1, 2024
13f46ea
Merge branch 'main' into feat-lora-base-class
sayakpaul Jul 2, 2024
6ccfc35
Merge branch 'main' into feat-lora-base-class
sayakpaul Jul 3, 2024
f765c0b
Merge branch 'main' into feat-lora-base-class
sayakpaul Jul 3, 2024
d84c51a
Merge branch 'main' into feat-lora-base-class
sayakpaul Jul 3, 2024
b404a32
rewrite fuse_lora a bit.
sayakpaul Jul 3, 2024
ad532a0
feedback
sayakpaul Jul 3, 2024
7e9b4e7
Merge branch 'main' into feat-lora-base-class
sayakpaul Jul 3, 2024
1ce7690
Merge branch 'main' into feat-lora-base-class
sayakpaul Jul 3, 2024
9a387e4
copy over load_lora_into_text_encoder.
sayakpaul Jul 3, 2024
108a24d
address dhruv's feedback.
sayakpaul Jul 3, 2024
6fcea65
fix-copies
sayakpaul Jul 3, 2024
3969c7b
fix issubclass.
sayakpaul Jul 3, 2024
6834f6f
num_fused_loras
sayakpaul Jul 3, 2024
7ad3ae4
fix
sayakpaul Jul 3, 2024
f312258
fix
sayakpaul Jul 3, 2024
f37d3c3
remove mapping
sayakpaul Jul 3, 2024
45f6813
up
sayakpaul Jul 3, 2024
c61a4c7
fix
sayakpaul Jul 3, 2024
34cb72a
style
sayakpaul Jul 3, 2024
90bff96
Merge branch 'main' into feat-lora-base-class
sayakpaul Jul 4, 2024
992b9c3
Merge branch 'main' into feat-lora-base-class
sayakpaul Jul 5, 2024
e97a737
resolve conflicts.
sayakpaul Jul 11, 2024
7c74f58
fix-copies
sayakpaul Jul 11, 2024
a9c94c1
Merge branch 'main' into feat-lora-base-class
sayakpaul Jul 15, 2024
32e8aa2
change to SD3TransformerLoRALoadersMixin
sayakpaul Jul 15, 2024
7dd9a4c
Apply suggestions from code review
sayakpaul Jul 15, 2024
94fb788
up
sayakpaul Jul 15, 2024
f4bc25f
handle wuerstchen
sayakpaul Jul 15, 2024
c5b109d
up
sayakpaul Jul 15, 2024
8d6db91
move lora to lora_pipeline.py
sayakpaul Jul 15, 2024
81414ec
up
sayakpaul Jul 15, 2024
af75abf
fix-copies
sayakpaul Jul 15, 2024
6626824
fix documentation.
sayakpaul Jul 15, 2024
a17646a
comment set_adapters().
sayakpaul Jul 15, 2024
07fd364
resolve conflicts.
sayakpaul Jul 23, 2024
11c6051
fix-copies
sayakpaul Jul 23, 2024
891260b
fix set_adapters() at the model level.
sayakpaul Jul 23, 2024
ab1926f
fix?
sayakpaul Jul 23, 2024
0293578
fix
sayakpaul Jul 23, 2024
9c868b9
Merge branch 'main' into feat-lora-base-class
sayakpaul Jul 23, 2024
aa67632
Merge branch 'main' into feat-lora-base-class
sayakpaul Jul 24, 2024
b5585d7
Merge branch 'main' into feat-lora-base-class
sayakpaul Jul 25, 2024
19 changes: 17 additions & 2 deletions docs/source/en/api/loaders/lora.md
@@ -12,10 +12,13 @@ specific language governing permissions and limitations under the License.

# LoRA

LoRA is a fast and lightweight training method that inserts and trains a significantly smaller number of parameters instead of all the model parameters. This produces a smaller file (~100 MBs) and makes it easier to quickly train a model to learn a new concept. LoRA weights are typically loaded into the UNet, text encoder or both. There are two classes for loading LoRA weights:
LoRA is a fast and lightweight training method that inserts and trains a significantly smaller number of parameters instead of all the model parameters. This produces a smaller file (~100 MBs) and makes it easier to quickly train a model to learn a new concept. LoRA weights are typically loaded into the denoiser, text encoder or both. The denoiser usually corresponds to a UNet ([`UNet2DConditionModel`], for example) or a Transformer ([`SD3Transformer2DModel`], for example). There are several classes for loading LoRA weights:

- [`LoraLoaderMixin`] provides functions for loading and unloading, fusing and unfusing, enabling and disabling, and more functions for managing LoRA weights. This class can be used with any model.
- [`StableDiffusionXLLoraLoaderMixin`] is a [Stable Diffusion (SDXL)](../../api/pipelines/stable_diffusion/stable_diffusion_xl) version of the [`LoraLoaderMixin`] class for loading and saving LoRA weights. It can only be used with the SDXL model.
- [`SD3LoraLoaderMixin`] provides similar functions for [Stable Diffusion 3](https://huggingface.co/blog/sd3).
- [`AmusedLoraLoaderMixin`] is for the [`AmusedPipeline`].
- [`LoraBaseMixin`] provides a base class with several utility methods to fuse, unfuse, and unload LoRAs, and more.
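The refactor this PR introduces can be sketched as a toy mixin hierarchy: a base class holds the bookkeeping shared by every pipeline (fusing, unfusing, adapter selection), while each subclass adds only its model-specific loading logic. The class and method bodies below are illustrative simplifications, not the actual diffusers implementation.

```python
class LoraBaseMixinSketch:
    """Toy base mixin: shared LoRA bookkeeping reused by every pipeline."""

    def __init__(self):
        self.num_fused_loras = 0
        self._adapters = {}

    def fuse_lora(self):
        # In diffusers this merges LoRA weights into the base parameters;
        # here we only track how many adapters have been fused.
        self.num_fused_loras += 1

    def unfuse_lora(self):
        self.num_fused_loras = max(0, self.num_fused_loras - 1)

    def set_adapters(self, names, weights):
        # Record which adapters are active and at what scale.
        self._adapters = dict(zip(names, weights))


class SDXLLoaderSketch(LoraBaseMixinSketch):
    """Subclass contributes only the model-specific piece."""

    def load_lora_weights(self, path):
        return f"loaded {path} into unet + text encoders"


pipe = SDXLLoaderSketch()
pipe.load_lora_weights("my_lora")
pipe.fuse_lora()
pipe.set_adapters(["style"], [0.8])
```

Centralizing the shared methods in one base class is what lets each new pipeline reuse fuse/unfuse/unload behavior instead of copying it.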

<Tip>

@@ -29,4 +32,16 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse

## StableDiffusionXLLoraLoaderMixin

[[autodoc]] loaders.lora.StableDiffusionXLLoraLoaderMixin
[[autodoc]] loaders.lora.StableDiffusionXLLoraLoaderMixin

## SD3LoraLoaderMixin

[[autodoc]] loaders.lora.SD3LoraLoaderMixin

## AmusedLoraLoaderMixin

[[autodoc]] loaders.lora.AmusedLoraLoaderMixin

## LoraBaseMixin

[[autodoc]] loaders.lora_base.LoraBaseMixin
10 changes: 5 additions & 5 deletions examples/amused/train_amused.py
@@ -41,7 +41,7 @@

import diffusers.optimization
from diffusers import AmusedPipeline, AmusedScheduler, EMAModel, UVit2DModel, VQModel
from diffusers.loaders import LoraLoaderMixin
from diffusers.loaders import AmusedLoraLoaderMixin
from diffusers.utils import is_wandb_available


@@ -532,7 +532,7 @@ def save_model_hook(models, weights, output_dir):
weights.pop()

if transformer_lora_layers_to_save is not None or text_encoder_lora_layers_to_save is not None:
LoraLoaderMixin.save_lora_weights(
AmusedLoraLoaderMixin.save_lora_weights(
output_dir,
transformer_lora_layers=transformer_lora_layers_to_save,
text_encoder_lora_layers=text_encoder_lora_layers_to_save,
@@ -566,11 +566,11 @@ def load_model_hook(models, input_dir):
raise ValueError(f"unexpected save model: {model.__class__}")

if transformer is not None or text_encoder_ is not None:
lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
LoraLoaderMixin.load_lora_into_text_encoder(
lora_state_dict, network_alphas = AmusedLoraLoaderMixin.lora_state_dict(input_dir)
AmusedLoraLoaderMixin.load_lora_into_text_encoder(
lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_
)
LoraLoaderMixin.load_lora_into_transformer(
AmusedLoraLoaderMixin.load_lora_into_transformer(
lora_state_dict, network_alphas=network_alphas, transformer=transformer
)

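The training-script change above swaps the generic mixin for `AmusedLoraLoaderMixin` inside accelerate-style save/load hooks. The pattern itself can be sketched in plain Python (function and variable names here are illustrative, not the script's API): the save hook collects each model's LoRA layers and pops the corresponding entry from `weights` so the caller does not also serialize the full model.

```python
def save_model_hook(models, weights, output_dir, store):
    """Collect per-model LoRA layers, popping `weights` so the caller
    (e.g. accelerate) skips serializing the full models again."""
    collected = {}
    for model in models:
        # Stand-in for extracting the model's LoRA layer state dict.
        collected[model] = f"lora::{model}"
        weights.pop()  # prevent double-saving by the framework
    store[output_dir] = collected


weights = ["w_transformer", "w_text_encoder"]
store = {}
save_model_hook(["transformer", "text_encoder"], weights, "ckpt", store)
```

The matching load hook would route each saved state dict back to the right component, mirroring `load_lora_into_transformer` and `load_lora_into_text_encoder` in the diff.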
17 changes: 15 additions & 2 deletions src/diffusers/loaders/__init__.py
@@ -55,11 +55,18 @@ def text_encoder_attn_modules(text_encoder):

if is_torch_available():
_import_structure["single_file_model"] = ["FromOriginalModelMixin"]
_import_structure["transformer_sd3"] = ["SD3TransformerLoadersMixin"]

_import_structure["unet"] = ["UNet2DConditionLoadersMixin"]
_import_structure["utils"] = ["AttnProcsLayers"]
if is_transformers_available():
_import_structure["single_file"] = ["FromSingleFileMixin"]
_import_structure["lora"] = ["LoraLoaderMixin", "StableDiffusionXLLoraLoaderMixin", "SD3LoraLoaderMixin"]
_import_structure["lora"] = [
"AmusedLoraLoaderMixin",
"LoraLoaderMixin",
"SD3LoraLoaderMixin",
"StableDiffusionXLLoraLoaderMixin",
]
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
_import_structure["ip_adapter"] = ["IPAdapterMixin"]

@@ -69,12 +76,18 @@ def text_encoder_attn_modules(text_encoder):
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
if is_torch_available():
from .single_file_model import FromOriginalModelMixin
from .transformer_sd3 import SD3TransformerLoadersMixin
from .unet import UNet2DConditionLoadersMixin
from .utils import AttnProcsLayers

if is_transformers_available():
from .ip_adapter import IPAdapterMixin
from .lora import LoraLoaderMixin, SD3LoraLoaderMixin, StableDiffusionXLLoraLoaderMixin
from .lora import (
AmusedLoraLoaderMixin,
LoraLoaderMixin,
SD3LoraLoaderMixin,
StableDiffusionXLLoraLoaderMixin,
)
from .single_file import FromSingleFileMixin
from .textual_inversion import TextualInversionLoaderMixin

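The `_import_structure` dict edited above feeds diffusers' lazy-import machinery: names are registered up front, but submodules are only imported when an attribute is first accessed. A minimal stdlib sketch of the same idea (an assumed simplification, not the actual `_LazyModule` implementation) looks like this:

```python
import importlib


class LazyModuleSketch:
    """Defer submodule imports until an attribute is first accessed."""

    def __init__(self, name, import_structure):
        self._name = name
        # Invert {submodule: [attr, ...]} into attr -> submodule.
        self._attr_to_module = {
            attr: submodule
            for submodule, attrs in import_structure.items()
            for attr in attrs
        }

    def __getattr__(self, attr):
        submodule = self._attr_to_module.get(attr)
        if submodule is None:
            raise AttributeError(attr)
        # Import happens only here, on first access.
        module = importlib.import_module(submodule)
        return getattr(module, attr)


# Expose json.dumps lazily, as a stand-in for the loader mixins above.
lazy = LazyModuleSketch("demo", {"json": ["dumps"]})
```

This keeps `import diffusers.loaders` cheap even as the list of mixins (now including `AmusedLoraLoaderMixin` and `SD3TransformerLoadersMixin`) grows.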