diff --git a/docs/source/en/api/loaders/lora.md b/docs/source/en/api/loaders/lora.md
index 3a4d21c6a019..7bb6a3dc6f51 100644
--- a/docs/source/en/api/loaders/lora.md
+++ b/docs/source/en/api/loaders/lora.md
@@ -12,10 +12,13 @@ specific language governing permissions and limitations under the License.
 
 # LoRA
 
-LoRA is a fast and lightweight training method that inserts and trains a significantly smaller number of parameters instead of all the model parameters. This produces a smaller file (~100 MBs) and makes it easier to quickly train a model to learn a new concept. LoRA weights are typically loaded into the UNet, text encoder or both. There are two classes for loading LoRA weights:
+LoRA is a fast and lightweight training method that inserts and trains a significantly smaller number of parameters instead of all the model parameters. This produces a smaller file (~100 MB) and makes it easier to quickly train a model to learn a new concept. LoRA weights are typically loaded into the denoiser, the text encoder, or both. The denoiser usually corresponds to a UNet ([`UNet2DConditionModel`], for example) or a Transformer ([`SD3Transformer2DModel`], for example). There are several classes for loading LoRA weights:
 
 - [`LoraLoaderMixin`] provides functions for loading and unloading, fusing and unfusing, enabling and disabling, and more functions for managing LoRA weights. This class can be used with any model.
 - [`StableDiffusionXLLoraLoaderMixin`] is a [Stable Diffusion (SDXL)](../../api/pipelines/stable_diffusion/stable_diffusion_xl) version of the [`LoraLoaderMixin`] class for loading and saving LoRA weights. It can only be used with the SDXL model.
+- [`SD3LoraLoaderMixin`] provides similar functions for [Stable Diffusion 3](https://huggingface.co/blog/sd3).
+- [`AmusedLoraLoaderMixin`] is for the [`AmusedPipeline`].
+- [`LoraBaseMixin`] provides a base class with several utility methods to fuse, unfuse, and unload LoRAs, and more.
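+
+For example, loading a LoRA into an SDXL pipeline looks like this (a minimal sketch; the checkpoint and weight file names are illustrative, reused from the `fuse_lora` docstring, and any LoRA checkpoint in the same layout works):
+
+```py
+import torch
+from diffusers import DiffusionPipeline
+
+pipeline = DiffusionPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+).to("cuda")
+# `load_lora_weights()` comes from the loader mixin mixed into the pipeline class.
+pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+```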
@@ -29,4 +32,16 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse ## StableDiffusionXLLoraLoaderMixin -[[autodoc]] loaders.lora.StableDiffusionXLLoraLoaderMixin \ No newline at end of file +[[autodoc]] loaders.lora.StableDiffusionXLLoraLoaderMixin + +## SD3LoraLoaderMixin + +[[autodoc]] loaders.lora.SD3LoraLoaderMixin + +## AmusedLoraLoaderMixin + +[[autodoc]] loaders.lora.AmusedLoraLoaderMixin + +## LoraBaseMixin + +[[autodoc]] loaders.lora_base.LoraBaseMixin \ No newline at end of file diff --git a/examples/amused/train_amused.py b/examples/amused/train_amused.py index 3ec0503dfdfe..ede51775dd8f 100644 --- a/examples/amused/train_amused.py +++ b/examples/amused/train_amused.py @@ -41,7 +41,7 @@ import diffusers.optimization from diffusers import AmusedPipeline, AmusedScheduler, EMAModel, UVit2DModel, VQModel -from diffusers.loaders import LoraLoaderMixin +from diffusers.loaders import AmusedLoraLoaderMixin from diffusers.utils import is_wandb_available @@ -532,7 +532,7 @@ def save_model_hook(models, weights, output_dir): weights.pop() if transformer_lora_layers_to_save is not None or text_encoder_lora_layers_to_save is not None: - LoraLoaderMixin.save_lora_weights( + AmusedLoraLoaderMixin.save_lora_weights( output_dir, transformer_lora_layers=transformer_lora_layers_to_save, text_encoder_lora_layers=text_encoder_lora_layers_to_save, @@ -566,11 +566,11 @@ def load_model_hook(models, input_dir): raise ValueError(f"unexpected save model: {model.__class__}") if transformer is not None or text_encoder_ is not None: - lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir) - LoraLoaderMixin.load_lora_into_text_encoder( + lora_state_dict, network_alphas = AmusedLoraLoaderMixin.lora_state_dict(input_dir) + AmusedLoraLoaderMixin.load_lora_into_text_encoder( lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_ ) - LoraLoaderMixin.load_lora_into_transformer( + AmusedLoraLoaderMixin.load_lora_into_transformer( lora_state_dict, network_alphas=network_alphas, transformer=transformer ) diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index ded53733f052..e905e4eac433 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -55,11 +55,18 @@ def text_encoder_attn_modules(text_encoder): if is_torch_available(): _import_structure["single_file_model"] = ["FromOriginalModelMixin"] + _import_structure["transformer_sd3"] = ["SD3TransformerLoadersMixin"] + _import_structure["unet"] = ["UNet2DConditionLoadersMixin"] _import_structure["utils"] = ["AttnProcsLayers"] if is_transformers_available(): _import_structure["single_file"] = ["FromSingleFileMixin"] - _import_structure["lora"] = ["LoraLoaderMixin", "StableDiffusionXLLoraLoaderMixin", "SD3LoraLoaderMixin"] + _import_structure["lora"] = [ + "AmusedLoraLoaderMixin", + "LoraLoaderMixin", + "SD3LoraLoaderMixin", + "StableDiffusionXLLoraLoaderMixin", + ] _import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"] _import_structure["ip_adapter"] = ["IPAdapterMixin"] @@ -69,12 +76,18 @@ def text_encoder_attn_modules(text_encoder): if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: if is_torch_available(): from .single_file_model import FromOriginalModelMixin + from .transformer_sd3 import SD3TransformerLoadersMixin from .unet import UNet2DConditionLoadersMixin from .utils import AttnProcsLayers if is_transformers_available(): from .ip_adapter import IPAdapterMixin - from .lora import LoraLoaderMixin, SD3LoraLoaderMixin, 
StableDiffusionXLLoraLoaderMixin + from .lora import ( + AmusedLoraLoaderMixin, + LoraLoaderMixin, + SD3LoraLoaderMixin, + StableDiffusionXLLoraLoaderMixin, + ) from .single_file import FromSingleFileMixin from .textual_inversion import TextualInversionLoaderMixin diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py index 539a5637b7be..345c5ebdaccd 100644 --- a/src/diffusers/loaders/lora.py +++ b/src/diffusers/loaders/lora.py @@ -11,49 +11,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import copy -import inspect import os -from pathlib import Path from typing import Callable, Dict, List, Optional, Union -import safetensors import torch -from huggingface_hub import model_info -from huggingface_hub.constants import HF_HUB_OFFLINE from huggingface_hub.utils import validate_hf_hub_args -from torch import nn -from ..models.modeling_utils import load_state_dict from ..utils import ( USE_PEFT_BACKEND, - _get_model_file, - convert_state_dict_to_diffusers, - convert_state_dict_to_peft, convert_unet_state_dict_to_peft, - delete_adapter_layers, get_adapter_name, get_peft_kwargs, - is_accelerate_available, is_peft_version, - is_transformers_available, logging, - recurse_remove_peft_layers, - scale_lora_layers, - set_adapter_layers, - set_weights_and_activate_adapters, ) +from .lora_base import LoraBaseMixin from .lora_conversion_utils import _convert_non_diffusers_lora_to_diffusers, _maybe_map_sgm_blocks_to_diffusers -if is_transformers_available(): - from transformers import PreTrainedModel - - from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules - -if is_accelerate_available(): - from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module - logger = logging.get_logger(__name__) TEXT_ENCODER_NAME = "text_encoder" @@ -66,7 +41,7 @@ LORA_DEPRECATION_MESSAGE = "You are using an old version of LoRA backend. This will be deprecated in the next releases in favor of PEFT make sure to install the latest PEFT and transformers packages in the future." -class LoraLoaderMixin: +class LoraLoaderMixin(LoraBaseMixin): r""" Load LoRA layers into [`UNet2DConditionModel`] and [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel). @@ -74,8 +49,7 @@ class LoraLoaderMixin: text_encoder_name = TEXT_ENCODER_NAME unet_name = UNET_NAME - transformer_name = TRANSFORMER_NAME - num_fused_loras = 0 + is_unet_denoiser = True def load_lora_weights( self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs @@ -214,64 +188,21 @@ def lora_state_dict( "framework": "pytorch", } - model_file = None - if not isinstance(pretrained_model_name_or_path_or_dict, dict): - # Let's first try to load .safetensors weights - if (use_safetensors and weight_name is None) or ( - weight_name is not None and weight_name.endswith(".safetensors") - ): - try: - # Here we're relaxing the loading check to enable more Inference API - # friendliness where sometimes, it's not at all possible to automatically - # determine `weight_name`. 
- if weight_name is None: - weight_name = cls._best_guess_weight_name( - pretrained_model_name_or_path_or_dict, - file_extension=".safetensors", - local_files_only=local_files_only, - ) - model_file = _get_model_file( - pretrained_model_name_or_path_or_dict, - weights_name=weight_name or LORA_WEIGHT_NAME_SAFE, - cache_dir=cache_dir, - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - token=token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - ) - state_dict = safetensors.torch.load_file(model_file, device="cpu") - except (IOError, safetensors.SafetensorError) as e: - if not allow_pickle: - raise e - # try loading non-safetensors weights - model_file = None - pass - - if model_file is None: - if weight_name is None: - weight_name = cls._best_guess_weight_name( - pretrained_model_name_or_path_or_dict, file_extension=".bin", local_files_only=local_files_only - ) - model_file = _get_model_file( - pretrained_model_name_or_path_or_dict, - weights_name=weight_name or LORA_WEIGHT_NAME, - cache_dir=cache_dir, - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - token=token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - ) - state_dict = load_state_dict(model_file) - else: - state_dict = pretrained_model_name_or_path_or_dict + state_dict = cls._fetch_state_dict( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + use_safetensors=use_safetensors, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + allow_pickle=allow_pickle, + ) network_alphas = None # TODO: replace it with a method from `state_dict_utils` @@ -292,81 +223,78 @@ def lora_state_dict( return state_dict, network_alphas - @classmethod - def _best_guess_weight_name( - cls, pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False + # This method acts as an interface to distinguish between `fuse_unet` and `fuse_transformer` arguments + # properly and lets us reuse the `fuse_lora` method of the superclass. + def fuse_lora( + self, + fuse_unet: bool = True, + fuse_text_encoder: bool = True, + lora_scale: float = 1.0, + safe_fusing: bool = False, + adapter_names: Optional[List[str]] = None, ): - if local_files_only or HF_HUB_OFFLINE: - raise ValueError("When using the offline mode, you must specify a `weight_name`.") - - targeted_files = [] - - if os.path.isfile(pretrained_model_name_or_path_or_dict): - return - elif os.path.isdir(pretrained_model_name_or_path_or_dict): - targeted_files = [ - f for f in os.listdir(pretrained_model_name_or_path_or_dict) if f.endswith(file_extension) - ] - else: - files_in_repo = model_info(pretrained_model_name_or_path_or_dict).siblings - targeted_files = [f.rfilename for f in files_in_repo if f.rfilename.endswith(file_extension)] - if len(targeted_files) == 0: - return - - # "scheduler" does not correspond to a LoRA checkpoint. - # "optimizer" does not correspond to a LoRA checkpoint - # only top-level checkpoints are considered and not the other ones, hence "checkpoint". 
- unallowed_substrings = {"scheduler", "optimizer", "checkpoint"} - targeted_files = list( - filter(lambda x: all(substring not in x for substring in unallowed_substrings), targeted_files) - ) + r""" + Fuses the LoRA parameters into the original parameters of the corresponding blocks. - if any(f.endswith(LORA_WEIGHT_NAME) for f in targeted_files): - targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME), targeted_files)) - elif any(f.endswith(LORA_WEIGHT_NAME_SAFE) for f in targeted_files): - targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME_SAFE), targeted_files)) + - if len(targeted_files) > 1: - raise ValueError( - f"Provided path contains more than one weights file in the {file_extension} format. Either specify `weight_name` in `load_lora_weights` or make sure there's only one `.safetensors` or `.bin` file in {pretrained_model_name_or_path_or_dict}." - ) - weight_name = targeted_files[0] - return weight_name + This is an experimental API. - @classmethod - def _optionally_disable_offloading(cls, _pipeline): - """ - Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU. + Args: - _pipeline (`DiffusionPipeline`): - The pipeline to disable offloading for. + fuse_unet (`bool`, defaults to `True`): Whether to fuse the UNet LoRA parameters. + fuse_text_encoder (`bool`, defaults to `True`): + Whether to fuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the + LoRA parameters then it won't have any effect. + lora_scale (`float`, defaults to 1.0): + Controls how much to influence the outputs with the LoRA parameters. + safe_fusing (`bool`, defaults to `False`): + Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them. + adapter_names (`List[str]`, *optional*): + Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused. + + Example: - Returns: - tuple: - A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True. + ```py + from diffusers import DiffusionPipeline + import torch + + pipeline = DiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 + ).to("cuda") + pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") + pipeline.fuse_lora(lora_scale=0.7) + ``` """ - is_model_cpu_offload = False - is_sequential_cpu_offload = False - - if _pipeline is not None and _pipeline.hf_device_map is None: - for _, component in _pipeline.components.items(): - if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"): - if not is_model_cpu_offload: - is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload) - if not is_sequential_cpu_offload: - is_sequential_cpu_offload = ( - isinstance(component._hf_hook, AlignDevicesHook) - or hasattr(component._hf_hook, "hooks") - and isinstance(component._hf_hook.hooks[0], AlignDevicesHook) - ) - - logger.info( - "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again." 
- ) - remove_hook_from_module(component, recurse=is_sequential_cpu_offload) + super().fuse_lora( + fuse_denoiser=fuse_unet, + fuse_text_encoder=fuse_text_encoder, + lora_scale=lora_scale, + safe_fusing=safe_fusing, + adapter_names=adapter_names, + ) - return (is_model_cpu_offload, is_sequential_cpu_offload) + # This method acts as an interface to distinguish between `unfuse_unet` and `unfuse_transformer` arguments + # properly and lets us reuse the `unfuse_lora` method of the superclass. + def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True): + r""" + Reverses the effect of + [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.fuse_lora). + + + + This is an experimental API. + + + + Args: + unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. + unfuse_text_encoder (`bool`, defaults to `True`): + Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the + LoRA parameters then it won't have any effect. + """ + super().unfuse_lora(unfuse_denoiser=unfuse_unet, unfuse_text_encoder=unfuse_text_encoder) @classmethod def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None): @@ -403,233 +331,12 @@ def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None state_dict, network_alphas=network_alphas, adapter_name=adapter_name, _pipeline=_pipeline ) - @classmethod - def load_lora_into_text_encoder( - cls, - state_dict, - network_alphas, - text_encoder, - prefix=None, - lora_scale=1.0, - adapter_name=None, - _pipeline=None, - ): - """ - This will load the LoRA layers specified in `state_dict` into `text_encoder` - - Parameters: - state_dict (`dict`): - A standard state dict containing the lora layer parameters. The key should be prefixed with an - additional `text_encoder` to distinguish between unet lora layers. - network_alphas (`Dict[str, float]`): - See `LoRALinearLayer` for more details. - text_encoder (`CLIPTextModel`): - The text encoder model to load the LoRA layers into. - prefix (`str`): - Expected prefix of the `text_encoder` in the `state_dict`. - lora_scale (`float`): - How much to scale the output of the lora linear layer before it is added with the output of the regular - lora layer. - adapter_name (`str`, *optional*): - Adapter name to be used for referencing the loaded adapter model. If not specified, it will use - `default_{i}` where i is the total number of adapters being loaded. - """ - if not USE_PEFT_BACKEND: - raise ValueError("PEFT backend is required for this method.") - - from peft import LoraConfig - - # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), - # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as - # their prefixes. - keys = list(state_dict.keys()) - prefix = cls.text_encoder_name if prefix is None else prefix - - # Safe prefix to check with. - if any(cls.text_encoder_name in key for key in keys): - # Load the layers corresponding to text encoder and make necessary adjustments. 
- text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix] - text_encoder_lora_state_dict = { - k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys - } - - if len(text_encoder_lora_state_dict) > 0: - logger.info(f"Loading {prefix}.") - rank = {} - text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict) - - # convert state dict - text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict) - - for name, _ in text_encoder_attn_modules(text_encoder): - for module in ("out_proj", "q_proj", "k_proj", "v_proj"): - rank_key = f"{name}.{module}.lora_B.weight" - if rank_key not in text_encoder_lora_state_dict: - continue - rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] - - for name, _ in text_encoder_mlp_modules(text_encoder): - for module in ("fc1", "fc2"): - rank_key = f"{name}.{module}.lora_B.weight" - if rank_key not in text_encoder_lora_state_dict: - continue - rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] - - if network_alphas is not None: - alpha_keys = [ - k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix - ] - network_alphas = { - k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys - } - - lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False) - if "use_dora" in lora_config_kwargs: - if lora_config_kwargs["use_dora"]: - if is_peft_version("<", "0.9.0"): - raise ValueError( - "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." - ) - else: - if is_peft_version("<", "0.9.0"): - lora_config_kwargs.pop("use_dora") - lora_config = LoraConfig(**lora_config_kwargs) - - # adapter_name - if adapter_name is None: - adapter_name = get_adapter_name(text_encoder) - - is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline) - - # inject LoRA layers and load the state dict - # in transformers we automatically check whether the adapter name is already in use or not - text_encoder.load_adapter( - adapter_name=adapter_name, - adapter_state_dict=text_encoder_lora_state_dict, - peft_config=lora_config, - ) - - # scale LoRA layers with `lora_scale` - scale_lora_layers(text_encoder, weight=lora_scale) - - text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype) - - # Offload back. - if is_model_cpu_offload: - _pipeline.enable_model_cpu_offload() - elif is_sequential_cpu_offload: - _pipeline.enable_sequential_cpu_offload() - # Unsafe code /> - - @classmethod - def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None): - """ - This will load the LoRA layers specified in `state_dict` into `transformer`. - - Parameters: - state_dict (`dict`): - A standard state dict containing the lora layer parameters. The keys can either be indexed directly - into the unet or prefixed with an additional `unet` which can be used to distinguish between text - encoder lora layers. - network_alphas (`Dict[str, float]`): - See `LoRALinearLayer` for more details. - unet (`UNet2DConditionModel`): - The UNet model to load the LoRA layers into. - adapter_name (`str`, *optional*): - Adapter name to be used for referencing the loaded adapter model. If not specified, it will use - `default_{i}` where i is the total number of adapters being loaded. 
- """ - from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict - - keys = list(state_dict.keys()) - - transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)] - state_dict = { - k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys - } - - if network_alphas is not None: - alpha_keys = [k for k in network_alphas.keys() if k.startswith(cls.transformer_name)] - network_alphas = { - k.replace(f"{cls.transformer_name}.", ""): v for k, v in network_alphas.items() if k in alpha_keys - } - - if len(state_dict.keys()) > 0: - if adapter_name in getattr(transformer, "peft_config", {}): - raise ValueError( - f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name." - ) - - rank = {} - for key, val in state_dict.items(): - if "lora_B" in key: - rank[key] = val.shape[1] - - lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict) - if "use_dora" in lora_config_kwargs: - if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"): - raise ValueError( - "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." - ) - else: - lora_config_kwargs.pop("use_dora") - lora_config = LoraConfig(**lora_config_kwargs) - - # adapter_name - if adapter_name is None: - adapter_name = get_adapter_name(transformer) - - # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks - # otherwise loading LoRA weights will lead to an error - is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline) - - inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name) - incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name) - - if incompatible_keys is not None: - # check only for unexpected keys - unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) - if unexpected_keys: - logger.warning( - f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " - f" {unexpected_keys}. " - ) - - # Offload back. - if is_model_cpu_offload: - _pipeline.enable_model_cpu_offload() - elif is_sequential_cpu_offload: - _pipeline.enable_sequential_cpu_offload() - # Unsafe code /> - - @property - def lora_scale(self) -> float: - # property function that returns the lora scale which can be set at run time by the pipeline. 
- # if _lora_scale has not been set, return 1 - return self._lora_scale if hasattr(self, "_lora_scale") else 1.0 - - def _remove_text_encoder_monkey_patch(self): - remove_method = recurse_remove_peft_layers - if hasattr(self, "text_encoder"): - remove_method(self.text_encoder) - # In case text encoder have no Lora attached - if getattr(self.text_encoder, "peft_config", None) is not None: - del self.text_encoder.peft_config - self.text_encoder._hf_peft_config_loaded = None - - if hasattr(self, "text_encoder_2"): - remove_method(self.text_encoder_2) - if getattr(self.text_encoder_2, "peft_config", None) is not None: - del self.text_encoder_2.peft_config - self.text_encoder_2._hf_peft_config_loaded = None - @classmethod def save_lora_weights( cls, save_directory: Union[str, os.PathLike], unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, text_encoder_lora_layers: Dict[str, torch.nn.Module] = None, - transformer_lora_layers: Dict[str, torch.nn.Module] = None, is_main_process: bool = True, weight_name: str = None, save_function: Callable = None, @@ -659,24 +366,14 @@ def save_lora_weights( """ state_dict = {} - def pack_weights(layers, prefix): - layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers - layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()} - return layers_state_dict - - if not (unet_lora_layers or text_encoder_lora_layers or transformer_lora_layers): - raise ValueError( - "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`, or `transformer_lora_layers`." - ) + if not (unet_lora_layers or text_encoder_lora_layers): + raise ValueError("You must pass at least one of `unet_lora_layers` and `text_encoder_lora_layers`.") if unet_lora_layers: - state_dict.update(pack_weights(unet_lora_layers, cls.unet_name)) + state_dict.update(cls.pack_weights(unet_lora_layers, cls.unet_name)) if text_encoder_lora_layers: - state_dict.update(pack_weights(text_encoder_lora_layers, cls.text_encoder_name)) - - if transformer_lora_layers: - state_dict.update(pack_weights(transformer_lora_layers, "transformer")) + state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name)) # Save the model cls.write_lora_layers( @@ -688,499 +385,10 @@ def pack_weights(layers, prefix): safe_serialization=safe_serialization, ) - @staticmethod - def write_lora_layers( - state_dict: Dict[str, torch.Tensor], - save_directory: str, - is_main_process: bool, - weight_name: str, - save_function: Callable, - safe_serialization: bool, - ): - if os.path.isfile(save_directory): - logger.error(f"Provided path ({save_directory}) should be a directory, not a file") - return - - if save_function is None: - if safe_serialization: - - def save_function(weights, filename): - return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"}) - - else: - save_function = torch.save - - os.makedirs(save_directory, exist_ok=True) - - if weight_name is None: - if safe_serialization: - weight_name = LORA_WEIGHT_NAME_SAFE - else: - weight_name = LORA_WEIGHT_NAME - - save_path = Path(save_directory, weight_name).as_posix() - save_function(state_dict, save_path) - logger.info(f"Model weights saved in {save_path}") - - def unload_lora_weights(self): - """ - Unloads the LoRA parameters. - - Examples: - - ```python - >>> # Assuming `pipeline` is already loaded with the LoRA parameters. - >>> pipeline.unload_lora_weights() - >>> ... 
- ``` - """ - if not USE_PEFT_BACKEND: - raise ValueError("PEFT backend is required for this method.") - - unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet - unet.unload_lora() - - # Safe to call the following regardless of LoRA. - self._remove_text_encoder_monkey_patch() - - def fuse_lora( - self, - fuse_unet: bool = True, - fuse_text_encoder: bool = True, - lora_scale: float = 1.0, - safe_fusing: bool = False, - adapter_names: Optional[List[str]] = None, - ): - r""" - Fuses the LoRA parameters into the original parameters of the corresponding blocks. - - - - This is an experimental API. - - - - Args: - fuse_unet (`bool`, defaults to `True`): Whether to fuse the UNet LoRA parameters. - fuse_text_encoder (`bool`, defaults to `True`): - Whether to fuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the - LoRA parameters then it won't have any effect. - lora_scale (`float`, defaults to 1.0): - Controls how much to influence the outputs with the LoRA parameters. - safe_fusing (`bool`, defaults to `False`): - Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them. - adapter_names (`List[str]`, *optional*): - Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused. - - Example: - - ```py - from diffusers import DiffusionPipeline - import torch - - pipeline = DiffusionPipeline.from_pretrained( - "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 - ).to("cuda") - pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") - pipeline.fuse_lora(lora_scale=0.7) - ``` - """ - from peft.tuners.tuners_utils import BaseTunerLayer - - if fuse_unet or fuse_text_encoder: - self.num_fused_loras += 1 - if self.num_fused_loras > 1: - logger.warning( - "The current API is supported for operating with a single LoRA file. You are trying to load and fuse more than one LoRA which is not well-supported.", - ) - - if fuse_unet: - unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet - unet.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names) - - def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None): - merge_kwargs = {"safe_merge": safe_fusing} - - for module in text_encoder.modules(): - if isinstance(module, BaseTunerLayer): - if lora_scale != 1.0: - module.scale_layer(lora_scale) - - # For BC with previous PEFT versions, we need to check the signature - # of the `merge` method to see if it supports the `adapter_names` argument. - supported_merge_kwargs = list(inspect.signature(module.merge).parameters) - if "adapter_names" in supported_merge_kwargs: - merge_kwargs["adapter_names"] = adapter_names - elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None: - raise ValueError( - "The `adapter_names` argument is not supported with your PEFT version. " - "Please upgrade to the latest version of PEFT. 
`pip install -U peft`" - ) - - module.merge(**merge_kwargs) - - if fuse_text_encoder: - if hasattr(self, "text_encoder"): - fuse_text_encoder_lora(self.text_encoder, lora_scale, safe_fusing, adapter_names=adapter_names) - if hasattr(self, "text_encoder_2"): - fuse_text_encoder_lora(self.text_encoder_2, lora_scale, safe_fusing, adapter_names=adapter_names) - - def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True): - r""" - Reverses the effect of - [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.fuse_lora). - - - - This is an experimental API. - - - - Args: - unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. - unfuse_text_encoder (`bool`, defaults to `True`): - Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the - LoRA parameters then it won't have any effect. - """ - from peft.tuners.tuners_utils import BaseTunerLayer - - unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet - if unfuse_unet: - for module in unet.modules(): - if isinstance(module, BaseTunerLayer): - module.unmerge() - - def unfuse_text_encoder_lora(text_encoder): - for module in text_encoder.modules(): - if isinstance(module, BaseTunerLayer): - module.unmerge() - - if unfuse_text_encoder: - if hasattr(self, "text_encoder"): - unfuse_text_encoder_lora(self.text_encoder) - if hasattr(self, "text_encoder_2"): - unfuse_text_encoder_lora(self.text_encoder_2) - - self.num_fused_loras -= 1 - - def set_adapters_for_text_encoder( - self, - adapter_names: Union[List[str], str], - text_encoder: Optional["PreTrainedModel"] = None, # noqa: F821 - text_encoder_weights: Optional[Union[float, List[float], List[None]]] = None, - ): - """ - Sets the adapter layers for the text encoder. - - Args: - adapter_names (`List[str]` or `str`): - The names of the adapters to use. - text_encoder (`torch.nn.Module`, *optional*): - The text encoder module to set the adapter layers for. If `None`, it will try to get the `text_encoder` - attribute. - text_encoder_weights (`List[float]`, *optional*): - The weights to use for the text encoder. If `None`, the weights are set to `1.0` for all the adapters. - """ - if not USE_PEFT_BACKEND: - raise ValueError("PEFT backend is required for this method.") - - def process_weights(adapter_names, weights): - # Expand weights into a list, one entry per adapter - # e.g. for 2 adapters: 7 -> [7,7] ; [3, None] -> [3, None] - if not isinstance(weights, list): - weights = [weights] * len(adapter_names) - - if len(adapter_names) != len(weights): - raise ValueError( - f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(weights)}" - ) - - # Set None values to default of 1.0 - # e.g. [7,7] -> [7,7] ; [3, None] -> [3,1] - weights = [w if w is not None else 1.0 for w in weights] - - return weights - - adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names - text_encoder_weights = process_weights(adapter_names, text_encoder_weights) - text_encoder = text_encoder or getattr(self, "text_encoder", None) - if text_encoder is None: - raise ValueError( - "The pipeline does not have a default `pipe.text_encoder` class. Please make sure to pass a `text_encoder` instead." 
- ) - set_weights_and_activate_adapters(text_encoder, adapter_names, text_encoder_weights) - - def disable_lora_for_text_encoder(self, text_encoder: Optional["PreTrainedModel"] = None): - """ - Disables the LoRA layers for the text encoder. - - Args: - text_encoder (`torch.nn.Module`, *optional*): - The text encoder module to disable the LoRA layers for. If `None`, it will try to get the - `text_encoder` attribute. - """ - if not USE_PEFT_BACKEND: - raise ValueError("PEFT backend is required for this method.") - - text_encoder = text_encoder or getattr(self, "text_encoder", None) - if text_encoder is None: - raise ValueError("Text Encoder not found.") - set_adapter_layers(text_encoder, enabled=False) - - def enable_lora_for_text_encoder(self, text_encoder: Optional["PreTrainedModel"] = None): - """ - Enables the LoRA layers for the text encoder. - - Args: - text_encoder (`torch.nn.Module`, *optional*): - The text encoder module to enable the LoRA layers for. If `None`, it will try to get the `text_encoder` - attribute. - """ - if not USE_PEFT_BACKEND: - raise ValueError("PEFT backend is required for this method.") - text_encoder = text_encoder or getattr(self, "text_encoder", None) - if text_encoder is None: - raise ValueError("Text Encoder not found.") - set_adapter_layers(self.text_encoder, enabled=True) - - def set_adapters( - self, - adapter_names: Union[List[str], str], - adapter_weights: Optional[Union[float, Dict, List[float], List[Dict]]] = None, - ): - adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names - - adapter_weights = copy.deepcopy(adapter_weights) - - # Expand weights into a list, one entry per adapter - if not isinstance(adapter_weights, list): - adapter_weights = [adapter_weights] * len(adapter_names) - - if len(adapter_names) != len(adapter_weights): - raise ValueError( - f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(adapter_weights)}" - ) - - # Decompose weights into weights for unet, text_encoder and text_encoder_2 - unet_lora_weights, text_encoder_lora_weights, text_encoder_2_lora_weights = [], [], [] - - list_adapters = self.get_list_adapters() # eg {"unet": ["adapter1", "adapter2"], "text_encoder": ["adapter2"]} - all_adapters = { - adapter for adapters in list_adapters.values() for adapter in adapters - } # eg ["adapter1", "adapter2"] - invert_list_adapters = { - adapter: [part for part, adapters in list_adapters.items() if adapter in adapters] - for adapter in all_adapters - } # eg {"adapter1": ["unet"], "adapter2": ["unet", "text_encoder"]} - - for adapter_name, weights in zip(adapter_names, adapter_weights): - if isinstance(weights, dict): - unet_lora_weight = weights.pop("unet", None) - text_encoder_lora_weight = weights.pop("text_encoder", None) - text_encoder_2_lora_weight = weights.pop("text_encoder_2", None) - - if len(weights) > 0: - raise ValueError( - f"Got invalid key '{weights.keys()}' in lora weight dict for adapter {adapter_name}." - ) - - if text_encoder_2_lora_weight is not None and not hasattr(self, "text_encoder_2"): - logger.warning( - "Lora weight dict contains text_encoder_2 weights but will be ignored because pipeline does not have text_encoder_2." 
- ) - - # warn if adapter doesn't have parts specified by adapter_weights - for part_weight, part_name in zip( - [unet_lora_weight, text_encoder_lora_weight, text_encoder_2_lora_weight], - ["unet", "text_encoder", "text_encoder_2"], - ): - if part_weight is not None and part_name not in invert_list_adapters[adapter_name]: - logger.warning( - f"Lora weight dict for adapter '{adapter_name}' contains {part_name}, but this will be ignored because {adapter_name} does not contain weights for {part_name}. Valid parts for {adapter_name} are: {invert_list_adapters[adapter_name]}." - ) - - else: - unet_lora_weight = weights - text_encoder_lora_weight = weights - text_encoder_2_lora_weight = weights - - unet_lora_weights.append(unet_lora_weight) - text_encoder_lora_weights.append(text_encoder_lora_weight) - text_encoder_2_lora_weights.append(text_encoder_2_lora_weight) - - unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet - # Handle the UNET - unet.set_adapters(adapter_names, unet_lora_weights) - - # Handle the Text Encoder - if hasattr(self, "text_encoder"): - self.set_adapters_for_text_encoder(adapter_names, self.text_encoder, text_encoder_lora_weights) - if hasattr(self, "text_encoder_2"): - self.set_adapters_for_text_encoder(adapter_names, self.text_encoder_2, text_encoder_2_lora_weights) - - def disable_lora(self): - if not USE_PEFT_BACKEND: - raise ValueError("PEFT backend is required for this method.") - - # Disable unet adapters - unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet - unet.disable_lora() - - # Disable text encoder adapters - if hasattr(self, "text_encoder"): - self.disable_lora_for_text_encoder(self.text_encoder) - if hasattr(self, "text_encoder_2"): - self.disable_lora_for_text_encoder(self.text_encoder_2) - - def enable_lora(self): - if not USE_PEFT_BACKEND: - raise ValueError("PEFT backend is required for this method.") - - # Enable unet adapters - unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet - unet.enable_lora() - - # Enable text encoder adapters - if hasattr(self, "text_encoder"): - self.enable_lora_for_text_encoder(self.text_encoder) - if hasattr(self, "text_encoder_2"): - self.enable_lora_for_text_encoder(self.text_encoder_2) - - def delete_adapters(self, adapter_names: Union[List[str], str]): - """ - Args: - Deletes the LoRA layers of `adapter_name` for the unet and text-encoder(s). - adapter_names (`Union[List[str], str]`): - The names of the adapter to delete. Can be a single string or a list of strings - """ - if not USE_PEFT_BACKEND: - raise ValueError("PEFT backend is required for this method.") - - if isinstance(adapter_names, str): - adapter_names = [adapter_names] - - # Delete unet adapters - unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet - unet.delete_adapters(adapter_names) - - for adapter_name in adapter_names: - # Delete text encoder adapters - if hasattr(self, "text_encoder"): - delete_adapter_layers(self.text_encoder, adapter_name) - if hasattr(self, "text_encoder_2"): - delete_adapter_layers(self.text_encoder_2, adapter_name) - - def get_active_adapters(self) -> List[str]: - """ - Gets the list of the current active adapters. 
- - Example: - - ```python - from diffusers import DiffusionPipeline - - pipeline = DiffusionPipeline.from_pretrained( - "stabilityai/stable-diffusion-xl-base-1.0", - ).to("cuda") - pipeline.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy") - pipeline.get_active_adapters() - ``` - """ - if not USE_PEFT_BACKEND: - raise ValueError( - "PEFT backend is required for this method. Please install the latest version of PEFT `pip install -U peft`" - ) - - from peft.tuners.tuners_utils import BaseTunerLayer - - active_adapters = [] - unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet - for module in unet.modules(): - if isinstance(module, BaseTunerLayer): - active_adapters = module.active_adapters - break - - return active_adapters - - def get_list_adapters(self) -> Dict[str, List[str]]: - """ - Gets the current list of all available adapters in the pipeline. - """ - if not USE_PEFT_BACKEND: - raise ValueError( - "PEFT backend is required for this method. Please install the latest version of PEFT `pip install -U peft`" - ) - - set_adapters = {} - - if hasattr(self, "text_encoder") and hasattr(self.text_encoder, "peft_config"): - set_adapters["text_encoder"] = list(self.text_encoder.peft_config.keys()) - - if hasattr(self, "text_encoder_2") and hasattr(self.text_encoder_2, "peft_config"): - set_adapters["text_encoder_2"] = list(self.text_encoder_2.peft_config.keys()) - - unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet - if hasattr(self, self.unet_name) and hasattr(unet, "peft_config"): - set_adapters[self.unet_name] = list(self.unet.peft_config.keys()) - - return set_adapters - - def set_lora_device(self, adapter_names: List[str], device: Union[torch.device, str, int]) -> None: - """ - Moves the LoRAs listed in `adapter_names` to a target device. Useful for offloading the LoRA to the CPU in case - you want to load multiple adapters and free some GPU memory. - - Args: - adapter_names (`List[str]`): - List of adapters to send device to. - device (`Union[torch.device, str, int]`): - Device to send the adapters to. Can be either a torch device, a str or an integer. 
- """ - if not USE_PEFT_BACKEND: - raise ValueError("PEFT backend is required for this method.") - - from peft.tuners.tuners_utils import BaseTunerLayer - - # Handle the UNET - unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet - for unet_module in unet.modules(): - if isinstance(unet_module, BaseTunerLayer): - for adapter_name in adapter_names: - unet_module.lora_A[adapter_name].to(device) - unet_module.lora_B[adapter_name].to(device) - # this is a param, not a module, so device placement is not in-place -> re-assign - if hasattr(unet_module, "lora_magnitude_vector") and unet_module.lora_magnitude_vector is not None: - unet_module.lora_magnitude_vector[adapter_name] = unet_module.lora_magnitude_vector[ - adapter_name - ].to(device) - - # Handle the text encoder - modules_to_process = [] - if hasattr(self, "text_encoder"): - modules_to_process.append(self.text_encoder) - - if hasattr(self, "text_encoder_2"): - modules_to_process.append(self.text_encoder_2) - - for text_encoder in modules_to_process: - # loop over submodules - for text_encoder_module in text_encoder.modules(): - if isinstance(text_encoder_module, BaseTunerLayer): - for adapter_name in adapter_names: - text_encoder_module.lora_A[adapter_name].to(device) - text_encoder_module.lora_B[adapter_name].to(device) - # this is a param, not a module, so device placement is not in-place -> re-assign - if ( - hasattr(text_encoder_module, "lora_magnitude_vector") - and text_encoder_module.lora_magnitude_vector is not None - ): - text_encoder_module.lora_magnitude_vector[ - adapter_name - ] = text_encoder_module.lora_magnitude_vector[adapter_name].to(device) - class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin): - """This class overrides `LoraLoaderMixin` with LoRA loading/saving code that's specific to SDXL""" + """This class overrides `LoraLoaderMixin` with LoRA loading/saving code that's specific + to SDXL ([`StableDiffusionXLPipeline`].)""" # Override to properly handle the loading and unloading of the additional text encoder. def load_lora_weights( @@ -1299,54 +507,115 @@ def save_lora_weights( """ state_dict = {} - def pack_weights(layers, prefix): - layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers - layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()} - return layers_state_dict - if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers): raise ValueError( "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`." ) - if unet_lora_layers: - state_dict.update(pack_weights(unet_lora_layers, "unet")) + if unet_lora_layers: + state_dict.update(cls.pack_weights(unet_lora_layers, "unet")) + + if text_encoder_lora_layers: + state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder")) + + if text_encoder_2_lora_layers: + state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2")) + + cls.write_lora_layers( + state_dict=state_dict, + save_directory=save_directory, + is_main_process=is_main_process, + weight_name=weight_name, + save_function=save_function, + safe_serialization=safe_serialization, + ) + + +class SD3LoraLoaderMixin(LoraBaseMixin): + r""" + Load LoRA layers into [`SD3Transformer2DModel`]. Specific to [`StableDiffusion3Pipeline`]. 
+ """ + + transformer_name = TRANSFORMER_NAME + text_encoder_name = TEXT_ENCODER_NAME + is_transformer_denoiser = True + + # This method acts as an interface to distinguish between `fuse_unet` and `fuse_transformer` arguments + # properly and lets us reuse the `fuse_lora` method of the superclass. + def fuse_lora( + self, + fuse_transformer: bool = True, + fuse_text_encoder: bool = True, + lora_scale: float = 1.0, + safe_fusing: bool = False, + adapter_names: Optional[List[str]] = None, + ): + r""" + Fuses the LoRA parameters into the original parameters of the corresponding blocks. + + + + This is an experimental API. + + + + Args: + fuse_transformer (`bool`, defaults to `True`): Whether to fuse the transformer LoRA parameters. + fuse_text_encoder (`bool`, defaults to `True`): + Whether to fuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the + LoRA parameters then it won't have any effect. + lora_scale (`float`, defaults to 1.0): + Controls how much to influence the outputs with the LoRA parameters. + safe_fusing (`bool`, defaults to `False`): + Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them. + adapter_names (`List[str]`, *optional*): + Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused. - if text_encoder_lora_layers: - state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder")) + Example: - if text_encoder_2_lora_layers: - state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2")) + ```py + from diffusers import DiffusionPipeline + import torch - cls.write_lora_layers( - state_dict=state_dict, - save_directory=save_directory, - is_main_process=is_main_process, - weight_name=weight_name, - save_function=save_function, - safe_serialization=safe_serialization, + pipeline = DiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16 + ).to("cuda") + pipeline.load_lora_weights( + "nerijs/pixel-art-medium-128-v0.1", + weight_name="pixel-art-medium-128-v0.1.safetensors", + adapter_name="pixel", + ) + pipeline.fuse_lora(lora_scale=0.7) + ``` + """ + super().fuse_lora( + fuse_denoiser=fuse_transformer, + fuse_text_encoder=fuse_text_encoder, + lora_scale=lora_scale, + safe_fusing=safe_fusing, + adapter_names=adapter_names, ) - def _remove_text_encoder_monkey_patch(self): - recurse_remove_peft_layers(self.text_encoder) - # TODO: @younesbelkada handle this in transformers side - if getattr(self.text_encoder, "peft_config", None) is not None: - del self.text_encoder.peft_config - self.text_encoder._hf_peft_config_loaded = None + # This method acts as an interface to distinguish between `unfuse_unet` and `unfuse_transformer` arguments + # properly and lets us reuse the `unfuse_lora` method of the superclass. + def unfuse_lora(self, unfuse_transformer: bool = True, unfuse_text_encoder: bool = True): + r""" + Reverses the effect of + [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.fuse_lora). - recurse_remove_peft_layers(self.text_encoder_2) - if getattr(self.text_encoder_2, "peft_config", None) is not None: - del self.text_encoder_2.peft_config - self.text_encoder_2._hf_peft_config_loaded = None + + This is an experimental API. -class SD3LoraLoaderMixin: - r""" - Load LoRA layers into [`SD3Transformer2DModel`]. 
- """ + - transformer_name = TRANSFORMER_NAME - num_fused_loras = 0 + Args: + unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters. + unfuse_text_encoder (`bool`, defaults to `True`): + Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the + LoRA parameters then it won't have any effect. + """ + super().unfuse_lora(unfuse_denoiser=unfuse_transformer, unfuse_text_encoder=unfuse_text_encoder) def load_lora_weights( self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs @@ -1392,6 +661,30 @@ def load_lora_weights( _pipeline=self, ) + text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k} + if len(text_encoder_state_dict) > 0: + self.load_lora_into_text_encoder( + text_encoder_state_dict, + network_alphas=None, + text_encoder=self.text_encoder, + prefix="text_encoder", + lora_scale=self.lora_scale, + adapter_name=adapter_name, + _pipeline=self, + ) + + text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k} + if len(text_encoder_2_state_dict) > 0: + self.load_lora_into_text_encoder( + text_encoder_2_state_dict, + network_alphas=None, + text_encoder=self.text_encoder_2, + prefix="text_encoder_2", + lora_scale=self.lora_scale, + adapter_name=adapter_name, + _pipeline=self, + ) + @classmethod @validate_hf_hub_args def lora_state_dict( @@ -1447,7 +740,7 @@ def lora_state_dict( """ # Load the main state dict first which has the LoRA layers for either of - # UNet and text encoder or both. + # transformer and text encoder or both. cache_dir = kwargs.pop("cache_dir", None) force_download = kwargs.pop("force_download", False) resume_download = kwargs.pop("resume_download", False) @@ -1469,51 +762,21 @@ def lora_state_dict( "framework": "pytorch", } - model_file = None - if not isinstance(pretrained_model_name_or_path_or_dict, dict): - # Let's first try to load .safetensors weights - if (use_safetensors and weight_name is None) or ( - weight_name is not None and weight_name.endswith(".safetensors") - ): - try: - model_file = _get_model_file( - pretrained_model_name_or_path_or_dict, - weights_name=weight_name or LORA_WEIGHT_NAME_SAFE, - cache_dir=cache_dir, - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - token=token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - ) - state_dict = safetensors.torch.load_file(model_file, device="cpu") - except (IOError, safetensors.SafetensorError) as e: - if not allow_pickle: - raise e - # try loading non-safetensors weights - model_file = None - pass - - if model_file is None: - model_file = _get_model_file( - pretrained_model_name_or_path_or_dict, - weights_name=weight_name or LORA_WEIGHT_NAME, - cache_dir=cache_dir, - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - token=token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - ) - state_dict = load_state_dict(model_file) - else: - state_dict = pretrained_model_name_or_path_or_dict + state_dict = cls._fetch_state_dict( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + use_safetensors=use_safetensors, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + 
token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + allow_pickle=allow_pickle, + ) return state_dict @@ -1615,6 +878,12 @@ def save_lora_weights( Directory to save LoRA parameters to. Will be created if it doesn't exist. transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): State dict of the LoRA layers corresponding to the `transformer`. + text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text + encoder LoRA state dict because it comes from 🤗 Transformers. + text_encoder_2_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `text_encoder_2`. Must explicitly pass the text + encoder LoRA state dict because it comes from 🤗 Transformers. is_main_process (`bool`, *optional*, defaults to `True`): Whether the process calling this is the main process or not. Useful during distributed training and you need to call this function on all processes. In this case, set `is_main_process=True` only on the main @@ -1628,24 +897,19 @@ def save_lora_weights( """ state_dict = {} - def pack_weights(layers, prefix): - layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers - layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()} - return layers_state_dict - if not (transformer_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers): raise ValueError( "You must pass at least one of `transformer_lora_layers`, `text_encoder_lora_layers`, `text_encoder_2_lora_layers`." ) if transformer_lora_layers: - state_dict.update(pack_weights(transformer_lora_layers, cls.transformer_name)) + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) if text_encoder_lora_layers: - state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder")) + state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder")) if text_encoder_2_lora_layers: - state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2")) + state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2")) # Save the model cls.write_lora_layers( @@ -1657,164 +921,152 @@ def pack_weights(layers, prefix): safe_serialization=safe_serialization, ) - @staticmethod - def write_lora_layers( - state_dict: Dict[str, torch.Tensor], - save_directory: str, - is_main_process: bool, - weight_name: str, - save_function: Callable, - safe_serialization: bool, - ): - if os.path.isfile(save_directory): - logger.error(f"Provided path ({save_directory}) should be a directory, not a file") - return - - if save_function is None: - if safe_serialization: - - def save_function(weights, filename): - return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"}) - else: - save_function = torch.save - - os.makedirs(save_directory, exist_ok=True) - - if weight_name is None: - if safe_serialization: - weight_name = LORA_WEIGHT_NAME_SAFE - else: - weight_name = LORA_WEIGHT_NAME - - save_path = Path(save_directory, weight_name).as_posix() - save_function(state_dict, save_path) - logger.info(f"Model weights saved in {save_path}") - - def unload_lora_weights(self): - """ - Unloads the LoRA parameters. +# The reason why we subclass from `LoraLoaderMixin` here is because Amused initially +# relied on `LoraLoaderMixin` for its LoRA support. 
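+# Concretely, the flags below flip the denoiser target: the helpers inherited from
+# `LoraLoaderMixin` should act on `self.transformer` (Amused's `UVit2DModel`) rather
+# than `self.unet`.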
+class AmusedLoraLoaderMixin(LoraLoaderMixin):
+    is_unet_denoiser = False
+    is_transformer_denoiser = True
+    transformer_name = TRANSFORMER_NAME

-    Examples:
+    @classmethod
+    def load_lora_into_unet(cls, **kwargs):
+        raise NotImplementedError("`load_lora_into_unet()` not implemented for `AmusedLoraLoaderMixin`.")

-    ```python
-    >>> # Assuming `pipeline` is already loaded with the LoRA parameters.
-    >>> pipeline.unload_lora_weights()
-    >>> ...
-    ```
+    @classmethod
+    def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None):
         """
-        transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
-        recurse_remove_peft_layers(transformer)
-        if hasattr(transformer, "peft_config"):
-            del transformer.peft_config
+        This will load the LoRA layers specified in `state_dict` into `transformer`.

-    @classmethod
-    # Copied from diffusers.loaders.lora.LoraLoaderMixin._optionally_disable_offloading
-    def _optionally_disable_offloading(cls, _pipeline):
+        Parameters:
+            state_dict (`dict`):
+                A standard state dict containing the lora layer parameters. The keys can either be indexed directly
+                into the transformer or prefixed with an additional `transformer`, which can be used to distinguish
+                them from the text encoder lora layers.
+            network_alphas (`Dict[str, float]`):
+                See `LoRALinearLayer` for more details.
+            transformer (`UVit2DModel`):
+                The Transformer model to load the LoRA layers into.
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where i is the total number of adapters being loaded.
         """
-        Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for this method.")

-        Args:
-            _pipeline (`DiffusionPipeline`):
-                The pipeline to disable offloading for.
+        from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict

-        Returns:
-            tuple:
-                A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
-        """
-        is_model_cpu_offload = False
-        is_sequential_cpu_offload = False
-
-        if _pipeline is not None and _pipeline.hf_device_map is None:
-            for _, component in _pipeline.components.items():
-                if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
-                    if not is_model_cpu_offload:
-                        is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
-                    if not is_sequential_cpu_offload:
-                        is_sequential_cpu_offload = (
-                            isinstance(component._hf_hook, AlignDevicesHook)
-                            or hasattr(component._hf_hook, "hooks")
-                            and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
-                        )
-
-                    logger.info(
-                        "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
- ) - remove_hook_from_module(component, recurse=is_sequential_cpu_offload) + keys = list(state_dict.keys()) - return (is_model_cpu_offload, is_sequential_cpu_offload) + transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)] + state_dict = { + k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys + } - def fuse_lora( - self, - fuse_transformer: bool = True, - lora_scale: float = 1.0, - safe_fusing: bool = False, - adapter_names: Optional[List[str]] = None, - ): - r""" - Fuses the LoRA parameters into the original parameters of the corresponding blocks. + if network_alphas is not None: + alpha_keys = [k for k in network_alphas.keys() if k.startswith(cls.transformer_name)] + network_alphas = { + k.replace(f"{cls.transformer_name}.", ""): v for k, v in network_alphas.items() if k in alpha_keys + } - + if len(state_dict.keys()) > 0: + if adapter_name in getattr(transformer, "peft_config", {}): + raise ValueError( + f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name." + ) - This is an experimental API. + rank = {} + for key, val in state_dict.items(): + if "lora_B" in key: + rank[key] = val.shape[1] - + lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict) + if "use_dora" in lora_config_kwargs: + if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"): + raise ValueError( + "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." + ) + else: + lora_config_kwargs.pop("use_dora") + lora_config = LoraConfig(**lora_config_kwargs) - Args: - fuse_transformer (`bool`, defaults to `True`): Whether to fuse the transformer LoRA parameters. - lora_scale (`float`, defaults to 1.0): - Controls how much to influence the outputs with the LoRA parameters. - safe_fusing (`bool`, defaults to `False`): - Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them. - adapter_names (`List[str]`, *optional*): - Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused. + # adapter_name + if adapter_name is None: + adapter_name = get_adapter_name(transformer) - Example: + # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks + # otherwise loading LoRA weights will lead to an error + is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline) - ```py - from diffusers import DiffusionPipeline - import torch + inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name) + incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name) - pipeline = DiffusionPipeline.from_pretrained( - "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16 - ).to("cuda") - pipeline.load_lora_weights( - "nerijs/pixel-art-medium-128-v0.1", - weight_name="pixel-art-medium-128-v0.1.safetensors", - adapter_name="pixel", - ) - pipeline.fuse_lora(lora_scale=0.7) - ``` - """ - if fuse_transformer: - self.num_fused_loras += 1 + if incompatible_keys is not None: + # check only for unexpected keys + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + logger.warning( + f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " + f" {unexpected_keys}. 
" + ) - if fuse_transformer: - transformer = ( - getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer - ) - transformer.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names) + # Offload back. + if is_model_cpu_offload: + _pipeline.enable_model_cpu_offload() + elif is_sequential_cpu_offload: + _pipeline.enable_sequential_cpu_offload() + # Unsafe code /> - def unfuse_lora(self, unfuse_transformer: bool = True): + @classmethod + def save_lora_weights( + cls, + save_directory: Union[str, os.PathLike], + text_encoder_lora_layers: Dict[str, torch.nn.Module] = None, + transformer_lora_layers: Dict[str, torch.nn.Module] = None, + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = True, + ): r""" - Reverses the effect of - [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.fuse_lora). - - + Save the LoRA parameters corresponding to the UNet and text encoder. - This is an experimental API. + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to save LoRA parameters to. Will be created if it doesn't exist. + unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `unet`. + text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text + encoder LoRA state dict because it comes from 🤗 Transformers. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful during distributed training and you + need to call this function on all processes. In this case, set `is_main_process=True` only on the main + process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful during distributed training when you need to + replace `torch.save` with another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + """ + state_dict = {} - + if not (transformer_lora_layers or text_encoder_lora_layers): + raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.") - Args: - unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters. 
- """ - from peft.tuners.tuners_utils import BaseTunerLayer + if transformer_lora_layers: + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) - transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer - if unfuse_transformer: - for module in transformer.modules(): - if isinstance(module, BaseTunerLayer): - module.unmerge() + if text_encoder_lora_layers: + state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name)) - self.num_fused_loras -= 1 + # Save the model + cls.write_lora_layers( + state_dict=state_dict, + save_directory=save_directory, + is_main_process=is_main_process, + weight_name=weight_name, + save_function=save_function, + safe_serialization=safe_serialization, + ) diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py new file mode 100644 index 000000000000..7f80797137d4 --- /dev/null +++ b/src/diffusers/loaders/lora_base.py @@ -0,0 +1,938 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import inspect +import os +from pathlib import Path +from typing import Callable, Dict, List, Optional, Union + +import safetensors +import torch +import torch.nn as nn +from huggingface_hub import model_info +from huggingface_hub.constants import HF_HUB_OFFLINE + +from ..models.modeling_utils import load_state_dict +from ..utils import ( + USE_PEFT_BACKEND, + _get_model_file, + convert_state_dict_to_diffusers, + convert_state_dict_to_peft, + delete_adapter_layers, + get_adapter_name, + get_peft_kwargs, + is_accelerate_available, + is_peft_version, + is_transformers_available, + logging, + recurse_remove_peft_layers, + scale_lora_layers, + set_adapter_layers, + set_weights_and_activate_adapters, +) + + +if is_transformers_available(): + from transformers import PreTrainedModel + + from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules + +if is_accelerate_available(): + from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module + +logger = logging.get_logger(__name__) + + +class LoraBaseMixin: + """Utility class for handling LoRAs.""" + + is_unet_denoiser = False + is_transformer_denoiser = False + num_fused_loras = 0 + + def _remove_text_encoder_monkey_patch(self): + if hasattr(self, "text_encoder"): + recurse_remove_peft_layers(self.text_encoder) + # TODO: @younesbelkada handle this in transformers side + if getattr(self.text_encoder, "peft_config", None) is not None: + del self.text_encoder.peft_config + self.text_encoder._hf_peft_config_loaded = None + + if hasattr(self, "text_encoder_2"): + recurse_remove_peft_layers(self.text_encoder_2) + if getattr(self.text_encoder_2, "peft_config", None) is not None: + del self.text_encoder_2.peft_config + self.text_encoder_2._hf_peft_config_loaded = None + + @classmethod + def _optionally_disable_offloading(cls, _pipeline): + """ + Optionally removes offloading in case the pipeline has been 
already sequentially offloaded to CPU. + + Args: + _pipeline (`DiffusionPipeline`): + The pipeline to disable offloading for. + + Returns: + tuple: + A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True. + """ + is_model_cpu_offload = False + is_sequential_cpu_offload = False + + if _pipeline is not None and _pipeline.hf_device_map is None: + for _, component in _pipeline.components.items(): + if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"): + if not is_model_cpu_offload: + is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload) + if not is_sequential_cpu_offload: + is_sequential_cpu_offload = ( + isinstance(component._hf_hook, AlignDevicesHook) + or hasattr(component._hf_hook, "hooks") + and isinstance(component._hf_hook.hooks[0], AlignDevicesHook) + ) + + logger.info( + "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again." + ) + remove_hook_from_module(component, recurse=is_sequential_cpu_offload) + + return (is_model_cpu_offload, is_sequential_cpu_offload) + + @classmethod + def _fetch_state_dict( + cls, + pretrained_model_name_or_path_or_dict, + weight_name, + use_safetensors, + local_files_only, + cache_dir, + force_download, + resume_download, + proxies, + token, + revision, + subfolder, + user_agent, + allow_pickle, + ): + from .lora import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE + + model_file = None + if not isinstance(pretrained_model_name_or_path_or_dict, dict): + # Let's first try to load .safetensors weights + if (use_safetensors and weight_name is None) or ( + weight_name is not None and weight_name.endswith(".safetensors") + ): + try: + # Here we're relaxing the loading check to enable more Inference API + # friendliness where sometimes, it's not at all possible to automatically + # determine `weight_name`. 
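+                    # `_best_guess_weight_name` (defined below) scans the Hub repo or local directory for
+                    # files with the requested extension and errors out if more than one candidate remains.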
+ if weight_name is None: + weight_name = cls._best_guess_weight_name( + pretrained_model_name_or_path_or_dict, + file_extension=".safetensors", + local_files_only=local_files_only, + ) + model_file = _get_model_file( + pretrained_model_name_or_path_or_dict, + weights_name=weight_name or LORA_WEIGHT_NAME_SAFE, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + state_dict = safetensors.torch.load_file(model_file, device="cpu") + except (IOError, safetensors.SafetensorError) as e: + if not allow_pickle: + raise e + # try loading non-safetensors weights + model_file = None + pass + + if model_file is None: + if weight_name is None: + weight_name = cls._best_guess_weight_name( + pretrained_model_name_or_path_or_dict, file_extension=".bin", local_files_only=local_files_only + ) + model_file = _get_model_file( + pretrained_model_name_or_path_or_dict, + weights_name=weight_name or LORA_WEIGHT_NAME, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + state_dict = load_state_dict(model_file) + else: + state_dict = pretrained_model_name_or_path_or_dict + + return state_dict + + @classmethod + def _best_guess_weight_name( + cls, pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False + ): + from .lora import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE + + if local_files_only or HF_HUB_OFFLINE: + raise ValueError("When using the offline mode, you must specify a `weight_name`.") + + targeted_files = [] + + if os.path.isfile(pretrained_model_name_or_path_or_dict): + return + elif os.path.isdir(pretrained_model_name_or_path_or_dict): + targeted_files = [ + f for f in os.listdir(pretrained_model_name_or_path_or_dict) if f.endswith(file_extension) + ] + else: + files_in_repo = model_info(pretrained_model_name_or_path_or_dict).siblings + targeted_files = [f.rfilename for f in files_in_repo if f.rfilename.endswith(file_extension)] + if len(targeted_files) == 0: + return + + # "scheduler" does not correspond to a LoRA checkpoint. + # "optimizer" does not correspond to a LoRA checkpoint + # only top-level checkpoints are considered and not the other ones, hence "checkpoint". + unallowed_substrings = {"scheduler", "optimizer", "checkpoint"} + targeted_files = list( + filter(lambda x: all(substring not in x for substring in unallowed_substrings), targeted_files) + ) + + if any(f.endswith(LORA_WEIGHT_NAME) for f in targeted_files): + targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME), targeted_files)) + elif any(f.endswith(LORA_WEIGHT_NAME_SAFE) for f in targeted_files): + targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME_SAFE), targeted_files)) + + if len(targeted_files) > 1: + raise ValueError( + f"Provided path contains more than one weights file in the {file_extension} format. Either specify `weight_name` in `load_lora_weights` or make sure there's only one `.safetensors` or `.bin` file in {pretrained_model_name_or_path_or_dict}." 
)
+        weight_name = targeted_files[0]
+        return weight_name
+
+    def load_lora_weights(self, **kwargs):
+        raise NotImplementedError("`load_lora_weights()` is not implemented.")
+
+    @classmethod
+    def save_lora_weights(cls, **kwargs):
+        raise NotImplementedError("`save_lora_weights()` not implemented.")
+
+    @classmethod
+    def lora_state_dict(cls, **kwargs):
+        raise NotImplementedError("`lora_state_dict()` is not implemented.")
+
+    @classmethod
+    def load_lora_into_text_encoder(
+        cls,
+        state_dict,
+        network_alphas,
+        text_encoder,
+        prefix=None,
+        lora_scale=1.0,
+        adapter_name=None,
+        _pipeline=None,
+    ):
+        """
+        This will load the LoRA layers specified in `state_dict` into `text_encoder`.
+
+        Parameters:
+            state_dict (`dict`):
+                A standard state dict containing the lora layer parameters. The keys should be prefixed with an
+                additional `text_encoder` to distinguish them from the denoiser lora layers.
+            network_alphas (`Dict[str, float]`):
+                See `LoRALinearLayer` for more details.
+            text_encoder (`CLIPTextModel`):
+                The text encoder model to load the LoRA layers into.
+            prefix (`str`):
+                Expected prefix of the `text_encoder` in the `state_dict`.
+            lora_scale (`float`):
+                How much to scale the output of the lora linear layer before it is added to the output of the regular
+                lora layer.
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where i is the total number of adapters being loaded.
+        """
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for this method.")
+
+        from peft import LoraConfig
+
+        # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
+        # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
+        # their prefixes.
+        keys = list(state_dict.keys())
+        prefix = cls.text_encoder_name if prefix is None else prefix
+
+        # Safe prefix to check with.
+        if any(cls.text_encoder_name in key for key in keys):
+            # Load the layers corresponding to text encoder and make necessary adjustments.
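+            # Strip the `{prefix}.` namespace so the remaining keys match the text encoder's module names.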
+ text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix] + text_encoder_lora_state_dict = { + k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys + } + + if len(text_encoder_lora_state_dict) > 0: + logger.info(f"Loading {prefix}.") + rank = {} + text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict) + + # convert state dict + text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict) + + for name, _ in text_encoder_attn_modules(text_encoder): + for module in ("out_proj", "q_proj", "k_proj", "v_proj"): + rank_key = f"{name}.{module}.lora_B.weight" + if rank_key not in text_encoder_lora_state_dict: + continue + rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] + + for name, _ in text_encoder_mlp_modules(text_encoder): + for module in ("fc1", "fc2"): + rank_key = f"{name}.{module}.lora_B.weight" + if rank_key not in text_encoder_lora_state_dict: + continue + rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] + + if network_alphas is not None: + alpha_keys = [ + k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix + ] + network_alphas = { + k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys + } + + lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False) + if "use_dora" in lora_config_kwargs: + if lora_config_kwargs["use_dora"]: + if is_peft_version("<", "0.9.0"): + raise ValueError( + "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." + ) + else: + if is_peft_version("<", "0.9.0"): + lora_config_kwargs.pop("use_dora") + lora_config = LoraConfig(**lora_config_kwargs) + + # adapter_name + if adapter_name is None: + adapter_name = get_adapter_name(text_encoder) + + is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline) + + # inject LoRA layers and load the state dict + # in transformers we automatically check whether the adapter name is already in use or not + text_encoder.load_adapter( + adapter_name=adapter_name, + adapter_state_dict=text_encoder_lora_state_dict, + peft_config=lora_config, + ) + + # scale LoRA layers with `lora_scale` + scale_lora_layers(text_encoder, weight=lora_scale) + + text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype) + + # Offload back. + if is_model_cpu_offload: + _pipeline.enable_model_cpu_offload() + elif is_sequential_cpu_offload: + _pipeline.enable_sequential_cpu_offload() + # Unsafe code /> + + def unload_lora_weights(self): + """ + Unloads the LoRA parameters. + + Examples: + + ```python + >>> # Assuming `pipeline` is already loaded with the LoRA parameters. + >>> pipeline.unload_lora_weights() + >>> ... + ``` + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + if self.is_unet_denoiser: + unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet + unet.unload_lora() + elif self.is_transformer_denoiser: + transformer = ( + getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer + ) + transformer.unload_lora() + else: + raise ValueError("No valid denoiser found in the network.") + + # Safe to call the following regardless of LoRA. 
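+        # (this is effectively a no-op when no PEFT layers were ever injected into the text encoder)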
+        self._remove_text_encoder_monkey_patch()
+
+    def fuse_lora(
+        self,
+        fuse_denoiser: bool = True,
+        fuse_text_encoder: bool = True,
+        lora_scale: float = 1.0,
+        safe_fusing: bool = False,
+        adapter_names: Optional[List[str]] = None,
+    ):
+        r"""
+        Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+        <Tip warning={true}>
+
+        This is an experimental API.
+
+        </Tip>
+
+        Args:
+            fuse_denoiser (`bool`, defaults to `True`):
+                Whether to fuse the denoiser (UNet, Transformer, etc.) LoRA parameters.
+            fuse_text_encoder (`bool`, defaults to `True`):
+                Whether to fuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
+                LoRA parameters then it won't have any effect.
+            lora_scale (`float`, defaults to 1.0):
+                Controls how much to influence the outputs with the LoRA parameters.
+            safe_fusing (`bool`, defaults to `False`):
+                Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
+            adapter_names (`List[str]`, *optional*):
+                Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+        Example:
+
+        ```py
+        from diffusers import DiffusionPipeline
+        import torch
+
+        pipeline = DiffusionPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+        pipeline.fuse_lora(lora_scale=0.7)
+        ```
+        """
+        from peft.tuners.tuners_utils import BaseTunerLayer
+
+        fuse_unet = True if fuse_denoiser and self.is_unet_denoiser else False
+        fuse_transformer = True if fuse_denoiser and self.is_transformer_denoiser else False
+
+        if fuse_unet:
+            unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+            unet.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names)
+        elif fuse_transformer:
+            transformer = (
+                getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
+            )
+            transformer.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names)
+
+        def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
+            merge_kwargs = {"safe_merge": safe_fusing}
+
+            for module in text_encoder.modules():
+                if isinstance(module, BaseTunerLayer):
+                    if lora_scale != 1.0:
+                        module.scale_layer(lora_scale)
+
+                    # For BC with previous PEFT versions, we need to check the signature
+                    # of the `merge` method to see if it supports the `adapter_names` argument.
+                    supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
+                    if "adapter_names" in supported_merge_kwargs:
+                        merge_kwargs["adapter_names"] = adapter_names
+                    elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
+                        raise ValueError(
+                            "The `adapter_names` argument is not supported with your PEFT version. "
+                            "Please upgrade to the latest version of PEFT.
`pip install -U peft`" + ) + + module.merge(**merge_kwargs) + + if fuse_text_encoder: + if hasattr(self, "text_encoder"): + fuse_text_encoder_lora(self.text_encoder, lora_scale, safe_fusing, adapter_names=adapter_names) + if hasattr(self, "text_encoder_2"): + fuse_text_encoder_lora(self.text_encoder_2, lora_scale, safe_fusing, adapter_names=adapter_names) + + if fuse_denoiser or fuse_text_encoder: + self.num_fused_loras += 1 + + def unfuse_lora(self, unfuse_denoiser: bool = True, unfuse_text_encoder: bool = True): + r""" + Reverses the effect of + [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.fuse_lora). + + + + This is an experimental API. + + + + Args: + unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. + unfuse_text_encoder (`bool`, defaults to `True`): + Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the + LoRA parameters then it won't have any effect. + """ + from peft.tuners.tuners_utils import BaseTunerLayer + + unfuse_unet = True if unfuse_denoiser and self.is_unet_denoiser else False + unfuse_transformer = True if unfuse_denoiser and self.is_transformer_denoiser else False + + if unfuse_unet: + unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet + for module in unet.modules(): + if isinstance(module, BaseTunerLayer): + module.unmerge() + elif unfuse_transformer: + transformer = ( + getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer + ) + for module in transformer.modules(): + if isinstance(module, BaseTunerLayer): + module.unmerge() + + def unfuse_text_encoder_lora(text_encoder): + for module in text_encoder.modules(): + if isinstance(module, BaseTunerLayer): + module.unmerge() + + if unfuse_text_encoder: + if hasattr(self, "text_encoder"): + unfuse_text_encoder_lora(self.text_encoder) + if hasattr(self, "text_encoder_2"): + unfuse_text_encoder_lora(self.text_encoder_2) + + self.num_fused_loras -= 1 + + def set_adapters_for_text_encoder( + self, + adapter_names: Union[List[str], str], + text_encoder: Optional["PreTrainedModel"] = None, # noqa: F821 + text_encoder_weights: Optional[Union[float, List[float], List[None]]] = None, + ): + """ + Sets the adapter layers for the text encoder. + + Args: + adapter_names (`List[str]` or `str`): + The names of the adapters to use. + text_encoder (`torch.nn.Module`, *optional*): + The text encoder module to set the adapter layers for. If `None`, it will try to get the `text_encoder` + attribute. + text_encoder_weights (`List[float]`, *optional*): + The weights to use for the text encoder. If `None`, the weights are set to `1.0` for all the adapters. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + def process_weights(adapter_names, weights): + # Expand weights into a list, one entry per adapter + # e.g. for 2 adapters: 7 -> [7,7] ; [3, None] -> [3, None] + if not isinstance(weights, list): + weights = [weights] * len(adapter_names) + + if len(adapter_names) != len(weights): + raise ValueError( + f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(weights)}" + ) + + # Set None values to default of 1.0 + # e.g. 
[7,7] -> [7,7] ; [3, None] -> [3,1]
+            weights = [w if w is not None else 1.0 for w in weights]
+
+            return weights
+
+        adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
+        text_encoder_weights = process_weights(adapter_names, text_encoder_weights)
+        text_encoder = text_encoder or getattr(self, "text_encoder", None)
+        if text_encoder is None:
+            raise ValueError(
+                "The pipeline does not have a default `pipe.text_encoder` class. Please make sure to pass a `text_encoder` instead."
+            )
+        set_weights_and_activate_adapters(text_encoder, adapter_names, text_encoder_weights)
+
+    def disable_lora_for_text_encoder(self, text_encoder: Optional["PreTrainedModel"] = None):
+        """
+        Disables the LoRA layers for the text encoder.
+
+        Args:
+            text_encoder (`torch.nn.Module`, *optional*):
+                The text encoder module to disable the LoRA layers for. If `None`, it will try to get the
+                `text_encoder` attribute.
+        """
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for this method.")
+
+        text_encoder = text_encoder or getattr(self, "text_encoder", None)
+        if text_encoder is None:
+            raise ValueError("Text Encoder not found.")
+        set_adapter_layers(text_encoder, enabled=False)
+
+    def enable_lora_for_text_encoder(self, text_encoder: Optional["PreTrainedModel"] = None):
+        """
+        Enables the LoRA layers for the text encoder.
+
+        Args:
+            text_encoder (`torch.nn.Module`, *optional*):
+                The text encoder module to enable the LoRA layers for. If `None`, it will try to get the `text_encoder`
+                attribute.
+        """
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for this method.")
+        text_encoder = text_encoder or getattr(self, "text_encoder", None)
+        if text_encoder is None:
+            raise ValueError("Text Encoder not found.")
+        set_adapter_layers(text_encoder, enabled=True)
+
+    def set_adapters(
+        self,
+        adapter_names: Union[List[str], str],
+        adapter_weights: Optional[Union[float, Dict, List[float], List[Dict]]] = None,
+    ):
+        adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
+
+        adapter_weights = copy.deepcopy(adapter_weights)
+
+        # Expand weights into a list, one entry per adapter
+        if not isinstance(adapter_weights, list):
+            adapter_weights = [adapter_weights] * len(adapter_names)
+
+        if len(adapter_names) != len(adapter_weights):
+            raise ValueError(
+                f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(adapter_weights)}"
+            )
+
+        # Decompose weights into weights for the denoiser, text_encoder and text_encoder_2
+        denoiser_lora_weights, text_encoder_lora_weights, text_encoder_2_lora_weights = [], [], []
+
+        list_adapters = self.get_list_adapters()  # eg {"unet": ["adapter1", "adapter2"], "text_encoder": ["adapter2"]}
+        all_adapters = {
+            adapter for adapters in list_adapters.values() for adapter in adapters
+        }  # eg ["adapter1", "adapter2"]
+        invert_list_adapters = {
+            adapter: [part for part, adapters in list_adapters.items() if adapter in adapters]
+            for adapter in all_adapters
+        }  # eg {"adapter1": ["unet"], "adapter2": ["unet", "text_encoder"]}
+
+        denoiser_name = "unet" if self.is_unet_denoiser else "transformer"
+        for adapter_name, weights in zip(adapter_names, adapter_weights):
+            if isinstance(weights, dict):
+                denoiser_lora_weight = weights.pop(denoiser_name, None)
+                text_encoder_lora_weight = weights.pop("text_encoder", None)
+                text_encoder_2_lora_weight = weights.pop("text_encoder_2", None)
+
+                if len(weights) > 0:
+                    raise ValueError(
+                        f"Got invalid key
'{weights.keys()}' in lora weight dict for adapter {adapter_name}."
+                    )
+
+                if text_encoder_2_lora_weight is not None and not hasattr(self, "text_encoder_2"):
+                    logger.warning(
+                        "Lora weight dict contains text_encoder_2 weights but will be ignored because pipeline does not have text_encoder_2."
+                    )
+
+                # warn if adapter doesn't have parts specified by adapter_weights
+                for part_weight, part_name in zip(
+                    [denoiser_lora_weight, text_encoder_lora_weight, text_encoder_2_lora_weight],
+                    [denoiser_name, "text_encoder", "text_encoder_2"],
+                ):
+                    if part_weight is not None and part_name not in invert_list_adapters[adapter_name]:
+                        logger.warning(
+                            f"Lora weight dict for adapter '{adapter_name}' contains {part_name}, but this will be ignored because {adapter_name} does not contain weights for {part_name}. Valid parts for {adapter_name} are: {invert_list_adapters[adapter_name]}."
+                        )
+
+            else:
+                denoiser_lora_weight = weights
+                text_encoder_lora_weight = weights
+                text_encoder_2_lora_weight = weights
+
+            denoiser_lora_weights.append(denoiser_lora_weight)
+            text_encoder_lora_weights.append(text_encoder_lora_weight)
+            text_encoder_2_lora_weights.append(text_encoder_2_lora_weight)
+
+        if denoiser_name == "unet":
+            # Handle the UNet
+            unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+            unet.set_adapters(adapter_names, denoiser_lora_weights)
+        else:
+            # Handle the Transformer
+            transformer = (
+                getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
+            )
+            transformer.set_adapters(adapter_names, denoiser_lora_weights)
+
+        # Handle the Text Encoder
+        if hasattr(self, "text_encoder"):
+            self.set_adapters_for_text_encoder(adapter_names, self.text_encoder, text_encoder_lora_weights)
+        if hasattr(self, "text_encoder_2"):
+            self.set_adapters_for_text_encoder(adapter_names, self.text_encoder_2, text_encoder_2_lora_weights)
+
+    def disable_lora(self):
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for this method.")
+
+        # Disable denoiser adapters
+        if self.is_unet_denoiser:
+            unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+            unet.disable_lora()
+        else:
+            transformer = (
+                getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
+            )
+            transformer.disable_lora()
+
+        # Disable text encoder adapters
+        if hasattr(self, "text_encoder"):
+            self.disable_lora_for_text_encoder(self.text_encoder)
+        if hasattr(self, "text_encoder_2"):
+            self.disable_lora_for_text_encoder(self.text_encoder_2)
+
+    def enable_lora(self):
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for this method.")
+
+        # Enable denoiser adapters
+        if self.is_unet_denoiser:
+            unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+            unet.enable_lora()
+        else:
+            transformer = (
+                getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
+            )
+            transformer.enable_lora()
+
+        # Enable text encoder adapters
+        if hasattr(self, "text_encoder"):
+            self.enable_lora_for_text_encoder(self.text_encoder)
+        if hasattr(self, "text_encoder_2"):
+            self.enable_lora_for_text_encoder(self.text_encoder_2)
+
+    def delete_adapters(self, adapter_names: Union[List[str], str]):
+        """
+        Deletes the LoRA layers of `adapter_names` for the denoiser and text encoder(s).
+
+        Args:
+            adapter_names (`Union[List[str], str]`):
+                The names of the adapter to delete.
Can be a single string or a list of strings
+        """
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for this method.")
+
+        if isinstance(adapter_names, str):
+            adapter_names = [adapter_names]
+
+        # Delete denoiser adapters
+        if self.is_unet_denoiser:
+            unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+            unet.delete_adapters(adapter_names)
+        else:
+            transformer = (
+                getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
+            )
+            transformer.delete_adapters(adapter_names)
+
+        for adapter_name in adapter_names:
+            # Delete text encoder adapters
+            if hasattr(self, "text_encoder"):
+                delete_adapter_layers(self.text_encoder, adapter_name)
+            if hasattr(self, "text_encoder_2"):
+                delete_adapter_layers(self.text_encoder_2, adapter_name)
+
+    def get_active_adapters(self) -> List[str]:
+        """
+        Gets the list of the current active adapters.
+
+        Example:
+
+        ```python
+        from diffusers import DiffusionPipeline
+
+        pipeline = DiffusionPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0",
+        ).to("cuda")
+        pipeline.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
+        pipeline.get_active_adapters()
+        ```
+        """
+        if not USE_PEFT_BACKEND:
+            raise ValueError(
+                "PEFT backend is required for this method. Please install the latest version of PEFT `pip install -U peft`"
+            )
+
+        from peft.tuners.tuners_utils import BaseTunerLayer
+
+        active_adapters = []
+        if self.is_unet_denoiser:
+            denoiser = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+        else:
+            denoiser = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
+
+        for module in denoiser.modules():
+            if isinstance(module, BaseTunerLayer):
+                active_adapters = module.active_adapters
+                break
+
+        return active_adapters
+
+    def get_list_adapters(self) -> Dict[str, List[str]]:
+        """
+        Gets the current list of all available adapters in the pipeline.
+        """
+        if not USE_PEFT_BACKEND:
+            raise ValueError(
+                "PEFT backend is required for this method. Please install the latest version of PEFT `pip install -U peft`"
+            )
+
+        set_adapters = {}
+
+        if hasattr(self, "text_encoder") and hasattr(self.text_encoder, "peft_config"):
+            set_adapters["text_encoder"] = list(self.text_encoder.peft_config.keys())
+
+        if hasattr(self, "text_encoder_2") and hasattr(self.text_encoder_2, "peft_config"):
+            set_adapters["text_encoder_2"] = list(self.text_encoder_2.peft_config.keys())
+
+        if self.is_unet_denoiser:
+            denoiser = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+            denoiser_name = self.unet_name
+        else:
+            denoiser = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
+            denoiser_name = self.transformer_name
+
+        if hasattr(self, denoiser_name) and hasattr(denoiser, "peft_config"):
+            set_adapters[denoiser_name] = list(denoiser.peft_config.keys())
+
+        return set_adapters
+
+    def set_lora_device(self, adapter_names: List[str], device: Union[torch.device, str, int]) -> None:
+        """
+        Moves the LoRAs listed in `adapter_names` to a target device. Useful for offloading the LoRA to the CPU in case
+        you want to load multiple adapters and free some GPU memory.
+
+        Args:
+            adapter_names (`List[str]`):
+                List of adapters to move to the target device.
+            device (`Union[torch.device, str, int]`):
+                Device to send the adapters to.
Can be either a torch device, a str or an integer. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + from peft.tuners.tuners_utils import BaseTunerLayer + + # Handle the denoiser + if self.is_unet_denoiser: + denoiser = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet + else: + denoiser = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer + + for denoiser_module in denoiser.modules(): + if isinstance(denoiser_module, BaseTunerLayer): + for adapter_name in adapter_names: + denoiser_module.lora_A[adapter_name].to(device) + denoiser_module.lora_B[adapter_name].to(device) + # this is a param, not a module, so device placement is not in-place -> re-assign + if ( + hasattr(denoiser_module, "lora_magnitude_vector") + and denoiser_module.lora_magnitude_vector is not None + ): + denoiser_module.lora_magnitude_vector[adapter_name] = denoiser_module.lora_magnitude_vector[ + adapter_name + ].to(device) + + # Handle the text encoder + modules_to_process = [] + if hasattr(self, "text_encoder"): + modules_to_process.append(self.text_encoder) + + if hasattr(self, "text_encoder_2"): + modules_to_process.append(self.text_encoder_2) + + for text_encoder in modules_to_process: + # loop over submodules + for text_encoder_module in text_encoder.modules(): + if isinstance(text_encoder_module, BaseTunerLayer): + for adapter_name in adapter_names: + text_encoder_module.lora_A[adapter_name].to(device) + text_encoder_module.lora_B[adapter_name].to(device) + # this is a param, not a module, so device placement is not in-place -> re-assign + if ( + hasattr(text_encoder_module, "lora_magnitude_vector") + and text_encoder_module.lora_magnitude_vector is not None + ): + text_encoder_module.lora_magnitude_vector[ + adapter_name + ] = text_encoder_module.lora_magnitude_vector[adapter_name].to(device) + + @staticmethod + def pack_weights(layers, prefix): + layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers + layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()} + return layers_state_dict + + @staticmethod + def write_lora_layers( + state_dict: Dict[str, torch.Tensor], + save_directory: str, + is_main_process: bool, + weight_name: str, + save_function: Callable, + safe_serialization: bool, + ): + from .lora import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE + + if os.path.isfile(save_directory): + logger.error(f"Provided path ({save_directory}) should be a directory, not a file") + return + + if save_function is None: + if safe_serialization: + + def save_function(weights, filename): + return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"}) + + else: + save_function = torch.save + + os.makedirs(save_directory, exist_ok=True) + + if weight_name is None: + if safe_serialization: + weight_name = LORA_WEIGHT_NAME_SAFE + else: + weight_name = LORA_WEIGHT_NAME + + save_path = Path(save_directory, weight_name).as_posix() + save_function(state_dict, save_path) + logger.info(f"Model weights saved in {save_path}") + + @property + def lora_scale(self) -> float: + # property function that returns the lora scale which can be set at run time by the pipeline. 
+        # if _lora_scale has not been set, return 1
+        return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
diff --git a/src/diffusers/loaders/transformer_sd3.py b/src/diffusers/loaders/transformer_sd3.py
new file mode 100644
index 000000000000..1f351132684f
--- /dev/null
+++ b/src/diffusers/loaders/transformer_sd3.py
@@ -0,0 +1,261 @@
+import inspect
+from functools import partial
+from typing import Dict, List, Optional, Union
+
+import torch.nn as nn
+
+from ..utils import (
+    USE_PEFT_BACKEND,
+    delete_adapter_layers,
+    is_accelerate_available,
+    logging,
+    set_adapter_layers,
+    set_weights_and_activate_adapters,
+)
+from .lora import TEXT_ENCODER_NAME, TRANSFORMER_NAME
+
+
+if is_accelerate_available():
+    from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
+
+logger = logging.get_logger(__name__)
+
+
+class SD3TransformerLoadersMixin:
+    """
+    Load LoRA layers into a [`SD3Transformer2DModel`].
+    """
+
+    text_encoder_name = TEXT_ENCODER_NAME
+    transformer_name = TRANSFORMER_NAME
+
+    @classmethod
+    # Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading
+    def _optionally_disable_offloading(cls, _pipeline):
+        """
+        Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
+
+        Args:
+            _pipeline (`DiffusionPipeline`):
+                The pipeline to disable offloading for.
+
+        Returns:
+            tuple:
+                A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
+        """
+        is_model_cpu_offload = False
+        is_sequential_cpu_offload = False
+
+        if _pipeline is not None and _pipeline.hf_device_map is None:
+            for _, component in _pipeline.components.items():
+                if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
+                    if not is_model_cpu_offload:
+                        is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
+                    if not is_sequential_cpu_offload:
+                        is_sequential_cpu_offload = (
+                            isinstance(component._hf_hook, AlignDevicesHook)
+                            or hasattr(component._hf_hook, "hooks")
+                            and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
+                        )
+
+                    logger.info(
+                        "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
+                    )
+                    remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
+
+        return (is_model_cpu_offload, is_sequential_cpu_offload)
+
+    # Copied from diffusers.loaders.unet.UNet2DConditionLoadersMixin.fuse_lora
+    def fuse_lora(self, lora_scale=1.0, safe_fusing=False, adapter_names=None):
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for `fuse_lora()`.")
+
+        self.lora_scale = lora_scale
+        self._safe_fusing = safe_fusing
+        self.apply(partial(self._fuse_lora_apply, adapter_names=adapter_names))
+
+    # Copied from diffusers.loaders.unet.UNet2DConditionLoadersMixin._fuse_lora_apply
+    def _fuse_lora_apply(self, module, adapter_names=None):
+        from peft.tuners.tuners_utils import BaseTunerLayer
+
+        merge_kwargs = {"safe_merge": self._safe_fusing}
+
+        if isinstance(module, BaseTunerLayer):
+            if self.lora_scale != 1.0:
+                module.scale_layer(self.lora_scale)
+
+            # For BC with previous PEFT versions, we need to check the signature
+            # of the `merge` method to see if it supports the `adapter_names` argument.
+            supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
+            if "adapter_names" in supported_merge_kwargs:
+                merge_kwargs["adapter_names"] = adapter_names
+            elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
+                raise ValueError(
+                    "The `adapter_names` argument is not supported with your PEFT version. Please upgrade"
+                    " to the latest version of PEFT. `pip install -U peft`"
+                )
+
+            module.merge(**merge_kwargs)
+
+    # Copied from diffusers.loaders.unet.UNet2DConditionLoadersMixin.unfuse_lora
+    def unfuse_lora(self):
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for `unfuse_lora()`.")
+        self.apply(self._unfuse_lora_apply)
+
+    # Copied from diffusers.loaders.unet.UNet2DConditionLoadersMixin._unfuse_lora_apply
+    def _unfuse_lora_apply(self, module):
+        from peft.tuners.tuners_utils import BaseTunerLayer
+
+        if isinstance(module, BaseTunerLayer):
+            module.unmerge()
+
+    # Copied from diffusers.loaders.unet.UNet2DConditionLoadersMixin.unload_lora
+    def unload_lora(self):
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for `unload_lora()`.")
+
+        from ..utils import recurse_remove_peft_layers
+
+        recurse_remove_peft_layers(self)
+        if hasattr(self, "peft_config"):
+            del self.peft_config
+
+    # This class is almost the same but it doesn't do `_maybe_expand_lora_scales()` yet. We will work on adding
+    # this support in a future PR.
+    def set_adapters(
+        self,
+        adapter_names: Union[List[str], str],
+        weights: Optional[Union[float, Dict, List[float], List[Dict], List[None]]] = None,
+    ):
+        """
+        Set the currently active adapters for use in the Transformer.
+
+        Args:
+            adapter_names (`List[str]` or `str`):
+                The names of the adapters to use.
+            weights (`Union[List[float], float]`, *optional*):
+                The adapter weight(s) to use with the Transformer. If `None`, the weights are set to `1.0` for all the
+                adapters.
+
+        Example:
+
+        ```py
+        from diffusers import AutoPipelineForText2Image
+        import torch
+
+        pipeline = AutoPipelineForText2Image.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.load_lora_weights(
+            "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
+        )
+        pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+        pipeline.set_adapters(["cinematic", "pixel"], adapter_weights=[0.5, 0.5])
+        ```
+        """
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for `set_adapters()`.")
+
+        adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
+
+        # Expand weights into a list, one entry per adapter
+        # examples for e.g. 2 adapters:  [{...}, 7] -> [7,7] ; None -> [None, None]
+        if not isinstance(weights, list):
+            weights = [weights] * len(adapter_names)
+
+        if len(adapter_names) != len(weights):
+            raise ValueError(
+                f"Length of adapter names {len(adapter_names)} is not equal to the length of their weights {len(weights)}."
+            )
+
+        # Set None values to default of 1.0
+        # e.g. [{...}, 7] -> [{...}, 7] ; [None, None] -> [1.0, 1.0]
+        weights = [w if w is not None else 1.0 for w in weights]
+
+        set_weights_and_activate_adapters(self, adapter_names, weights)
+
+    # Copied from diffusers.loaders.unet.UNet2DConditionLoadersMixin.disable_lora with UNet->Transformer
+    def disable_lora(self):
+        """
+        Disable the Transformer's active LoRA layers.
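+
+        The LoRA layers stay loaded on the model; use [`~SD3TransformerLoadersMixin.enable_lora`] to
+        re-activate them.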
+
+        Example:
+
+        ```py
+        from diffusers import AutoPipelineForText2Image
+        import torch
+
+        pipeline = AutoPipelineForText2Image.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.load_lora_weights(
+            "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
+        )
+        pipeline.disable_lora()
+        ```
+        """
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for this method.")
+        set_adapter_layers(self, enabled=False)
+
+    # Copied from diffusers.loaders.unet.UNet2DConditionLoadersMixin.enable_lora with UNet->Transformer
+    def enable_lora(self):
+        """
+        Enable the Transformer's active LoRA layers.
+
+        Example:
+
+        ```py
+        from diffusers import AutoPipelineForText2Image
+        import torch
+
+        pipeline = AutoPipelineForText2Image.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.load_lora_weights(
+            "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
+        )
+        pipeline.enable_lora()
+        ```
+        """
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for this method.")
+        set_adapter_layers(self, enabled=True)
+
+    # Copied from diffusers.loaders.unet.UNet2DConditionLoadersMixin.delete_adapters with UNet->Transformer
+    def delete_adapters(self, adapter_names: Union[List[str], str]):
+        """
+        Delete an adapter's LoRA layers from the Transformer.
+
+        Args:
+            adapter_names (`Union[List[str], str]`):
+                The names (single string or list of strings) of the adapter to delete.
+
+        Example:
+
+        ```py
+        from diffusers import AutoPipelineForText2Image
+        import torch
+
+        pipeline = AutoPipelineForText2Image.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.load_lora_weights(
+            "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
+        )
+        pipeline.delete_adapters("cinematic")
+        ```
+        """
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for this method.")
+
+        if isinstance(adapter_names, str):
+            adapter_names = [adapter_names]
+
+        for adapter_name in adapter_names:
+            delete_adapter_layers(self, adapter_name)
+
+            # Pop also the corresponding adapter from the config
+            if hasattr(self, "peft_config"):
+                self.peft_config.pop(adapter_name, None)
diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py
index 58c9c0e60d17..f20e3645ea9c 100644
--- a/src/diffusers/loaders/unet.py
+++ b/src/diffusers/loaders/unet.py
@@ -362,7 +362,7 @@ def _process_lora(self, state_dict, unet_identifier_key, network_alphas, adapter
         return is_model_cpu_offload, is_sequential_cpu_offload

     @classmethod
-    # Copied from diffusers.loaders.lora.LoraLoaderMixin._optionally_disable_offloading
+    # Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading
     def _optionally_disable_offloading(cls, _pipeline):
         """
         Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py
index 1b9126b3b849..71fd22fd1943 100644
--- a/src/diffusers/models/transformers/transformer_sd3.py
+++ b/src/diffusers/models/transformers/transformer_sd3.py
@@ -13,15 +13,13 @@
 # limitations under the License.
-import inspect -from functools import partial from typing import Any, Dict, List, Optional, Union import torch import torch.nn as nn from ...configuration_utils import ConfigMixin, register_to_config -from ...loaders import FromOriginalModelMixin, PeftAdapterMixin +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, SD3TransformerLoadersMixin from ...models.attention import JointTransformerBlock from ...models.attention_processor import Attention, AttentionProcessor from ...models.modeling_utils import ModelMixin @@ -34,7 +32,9 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): +class SD3Transformer2DModel( + ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, SD3TransformerLoadersMixin +): """ The Transformer model introduced in Stable Diffusion 3. @@ -241,47 +241,6 @@ def _set_gradient_checkpointing(self, module, value=False): if hasattr(module, "gradient_checkpointing"): module.gradient_checkpointing = value - def fuse_lora(self, lora_scale=1.0, safe_fusing=False, adapter_names=None): - if not USE_PEFT_BACKEND: - raise ValueError("PEFT backend is required for `fuse_lora()`.") - - self.lora_scale = lora_scale - self._safe_fusing = safe_fusing - self.apply(partial(self._fuse_lora_apply, adapter_names=adapter_names)) - - def _fuse_lora_apply(self, module, adapter_names=None): - from peft.tuners.tuners_utils import BaseTunerLayer - - merge_kwargs = {"safe_merge": self._safe_fusing} - - if isinstance(module, BaseTunerLayer): - if self.lora_scale != 1.0: - module.scale_layer(self.lora_scale) - - # For BC with prevous PEFT versions, we need to check the signature - # of the `merge` method to see if it supports the `adapter_names` argument. - supported_merge_kwargs = list(inspect.signature(module.merge).parameters) - if "adapter_names" in supported_merge_kwargs: - merge_kwargs["adapter_names"] = adapter_names - elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None: - raise ValueError( - "The `adapter_names` argument is not supported with your PEFT version. Please upgrade" - " to the latest version of PEFT. 
`pip install -U peft`" - ) - - module.merge(**merge_kwargs) - - def unfuse_lora(self): - if not USE_PEFT_BACKEND: - raise ValueError("PEFT backend is required for `unfuse_lora()`.") - self.apply(self._unfuse_lora_apply) - - def _unfuse_lora_apply(self, module): - from peft.tuners.tuners_utils import BaseTunerLayer - - if isinstance(module, BaseTunerLayer): - module.unmerge() - def forward( self, hidden_states: torch.FloatTensor, diff --git a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py index 7542f895ccf2..f5246027a49d 100644 --- a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +++ b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py @@ -30,9 +30,12 @@ from ...models.transformers import SD3Transformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( + USE_PEFT_BACKEND, is_torch_xla_available, logging, replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline @@ -346,6 +349,7 @@ def encode_prompt( negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, clip_skip: Optional[int] = None, max_sequence_length: int = 256, + lora_scale: Optional[float] = None, ): r""" @@ -391,9 +395,22 @@ def encode_prompt( clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 
""" device = device or self._execution_device + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + prompt = [prompt] if isinstance(prompt, str) else prompt if prompt is not None: batch_size = len(prompt) @@ -496,6 +513,16 @@ def encode_prompt( [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1 ) + if self.text_encoder is not None: + if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds def check_inputs( diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py index 3980d323c94c..bacbf8b66577 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py @@ -29,9 +29,12 @@ from ...models.transformers import SD3Transformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( + USE_PEFT_BACKEND, is_torch_xla_available, logging, replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline @@ -329,6 +332,7 @@ def encode_prompt( negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, clip_skip: Optional[int] = None, max_sequence_length: int = 256, + lora_scale: Optional[float] = None, ): r""" @@ -374,9 +378,22 @@ def encode_prompt( clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 
""" device = device or self._execution_device + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + prompt = [prompt] if isinstance(prompt, str) else prompt if prompt is not None: batch_size = len(prompt) @@ -479,6 +496,16 @@ def encode_prompt( [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1 ) + if self.text_encoder is not None: + if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds def check_inputs( @@ -787,6 +814,9 @@ def __call__( device = self._execution_device + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) ( prompt_embeds, negative_prompt_embeds, @@ -808,6 +838,7 @@ def __call__( clip_skip=self.clip_skip, num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, + lora_scale=lora_scale, ) if self.do_classifier_free_guidance: diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py index 093616f5432d..c67dada0b86c 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py @@ -25,13 +25,17 @@ ) from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import SD3LoraLoaderMixin from ...models.autoencoders import AutoencoderKL from ...models.transformers import SD3Transformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( + USE_PEFT_BACKEND, is_torch_xla_available, logging, replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline @@ -346,6 +350,7 @@ def encode_prompt( negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, clip_skip: Optional[int] = None, max_sequence_length: int = 256, + lora_scale: Optional[float] = None, ): r""" @@ -391,9 +396,22 @@ def encode_prompt( clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 
""" device = device or self._execution_device + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + prompt = [prompt] if isinstance(prompt, str) else prompt if prompt is not None: batch_size = len(prompt) @@ -496,6 +514,16 @@ def encode_prompt( [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1 ) + if self.text_encoder is not None: + if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds def check_inputs( diff --git a/tests/lora/test_lora_layers_sd3.py b/tests/lora/test_lora_layers_sd3.py index 2d4069d720ef..9ce559be7f06 100644 --- a/tests/lora/test_lora_layers_sd3.py +++ b/tests/lora/test_lora_layers_sd3.py @@ -12,377 +12,55 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os import sys -import tempfile import unittest -import numpy as np -import torch -from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel - from diffusers import ( - AutoencoderKL, FlowMatchEulerDiscreteScheduler, - SD3Transformer2DModel, StableDiffusion3Pipeline, ) from diffusers.utils.testing_utils import is_peft_available, require_peft_backend, require_torch_gpu, torch_device if is_peft_available(): - from peft import LoraConfig - from peft.utils import get_peft_model_state_dict + pass sys.path.append(".") -from utils import check_if_lora_correctly_set # noqa: E402 +from utils import PeftLoraLoaderMixinTests # noqa: E402 @require_peft_backend -class SD3LoRATests(unittest.TestCase): +class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): pipeline_class = StableDiffusion3Pipeline - - def get_dummy_components(self): - torch.manual_seed(0) - transformer = SD3Transformer2DModel( - sample_size=32, - patch_size=1, - in_channels=4, - num_layers=1, - attention_head_dim=8, - num_attention_heads=4, - caption_projection_dim=32, - joint_attention_dim=32, - pooled_projection_dim=64, - out_channels=4, - ) - clip_text_encoder_config = CLIPTextConfig( - bos_token_id=0, - eos_token_id=2, - hidden_size=32, - intermediate_size=37, - layer_norm_eps=1e-05, - num_attention_heads=4, - num_hidden_layers=5, - pad_token_id=1, - vocab_size=1000, - hidden_act="gelu", - projection_dim=32, - ) - - torch.manual_seed(0) - text_encoder = CLIPTextModelWithProjection(clip_text_encoder_config) - - torch.manual_seed(0) - text_encoder_2 = CLIPTextModelWithProjection(clip_text_encoder_config) - - text_encoder_3 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") - - tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") - tokenizer_2 = 
CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") - tokenizer_3 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") - - torch.manual_seed(0) - vae = AutoencoderKL( - sample_size=32, - in_channels=3, - out_channels=3, - block_out_channels=(4,), - layers_per_block=1, - latent_channels=4, - norm_num_groups=1, - use_quant_conv=False, - use_post_quant_conv=False, - shift_factor=0.0609, - scaling_factor=1.5035, - ) - - scheduler = FlowMatchEulerDiscreteScheduler() - - return { - "scheduler": scheduler, - "text_encoder": text_encoder, - "text_encoder_2": text_encoder_2, - "text_encoder_3": text_encoder_3, - "tokenizer": tokenizer, - "tokenizer_2": tokenizer_2, - "tokenizer_3": tokenizer_3, - "transformer": transformer, - "vae": vae, - } - - def get_dummy_inputs(self, device, seed=0): - if str(device).startswith("mps"): - generator = torch.manual_seed(seed) - else: - generator = torch.Generator(device="cpu").manual_seed(seed) - - inputs = { - "prompt": "A painting of a squirrel eating a burger", - "generator": generator, - "num_inference_steps": 2, - "guidance_scale": 5.0, - "output_type": "np", - } - return inputs - - def get_lora_config_for_transformer(self): - lora_config = LoraConfig( - r=4, - lora_alpha=4, - target_modules=["to_q", "to_k", "to_v", "to_out.0"], - init_lora_weights=False, - use_dora=False, - ) - return lora_config - - def get_lora_config_for_text_encoders(self): - text_lora_config = LoraConfig( - r=4, - lora_alpha=4, - init_lora_weights="gaussian", - target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], - ) - return text_lora_config - - def test_simple_inference_with_transformer_lora_save_load(self): - components = self.get_dummy_components() - transformer_config = self.get_lora_config_for_transformer() - - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - inputs = self.get_dummy_inputs(torch_device) - - pipe.transformer.add_adapter(transformer_config) - self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer") - inputs = self.get_dummy_inputs(torch_device) - images_lora = pipe(**inputs).images - - with tempfile.TemporaryDirectory() as tmpdirname: - transformer_state_dict = get_peft_model_state_dict(pipe.transformer) - - self.pipeline_class.save_lora_weights( - save_directory=tmpdirname, - transformer_lora_layers=transformer_state_dict, - ) - - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) - pipe.unload_lora_weights() - - pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - - inputs = self.get_dummy_inputs(torch_device) - images_lora_from_pretrained = pipe(**inputs).images - self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer") - - self.assertTrue( - np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3), - "Loading from saved checkpoints should give same results.", - ) - - def test_simple_inference_with_clip_encoders_lora_save_load(self): - components = self.get_dummy_components() - transformer_config = self.get_lora_config_for_transformer() - text_encoder_config = self.get_lora_config_for_text_encoders() - - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - inputs = self.get_dummy_inputs(torch_device) - - pipe.transformer.add_adapter(transformer_config) - 
pipe.text_encoder.add_adapter(text_encoder_config) - pipe.text_encoder_2.add_adapter(text_encoder_config) - - self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer") - self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder.") - self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2.") - - inputs = self.get_dummy_inputs(torch_device) - images_lora = pipe(**inputs).images - - with tempfile.TemporaryDirectory() as tmpdirname: - transformer_state_dict = get_peft_model_state_dict(pipe.transformer) - text_encoder_one_state_dict = get_peft_model_state_dict(pipe.text_encoder) - text_encoder_two_state_dict = get_peft_model_state_dict(pipe.text_encoder_2) - - self.pipeline_class.save_lora_weights( - save_directory=tmpdirname, - transformer_lora_layers=transformer_state_dict, - text_encoder_lora_layers=text_encoder_one_state_dict, - text_encoder_2_lora_layers=text_encoder_two_state_dict, - ) - - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) - pipe.unload_lora_weights() - - pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - - inputs = self.get_dummy_inputs(torch_device) - images_lora_from_pretrained = pipe(**inputs).images - self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer") - self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text_encoder_one") - self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text_encoder_two") - - self.assertTrue( - np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3), - "Loading from saved checkpoints should give same results.", - ) - - def test_simple_inference_with_transformer_lora_and_scale(self): - components = self.get_dummy_components() - transformer_lora_config = self.get_lora_config_for_transformer() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - - inputs = self.get_dummy_inputs(torch_device) - output_no_lora = pipe(**inputs).images - - pipe.transformer.add_adapter(transformer_lora_config) - self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer") - - inputs = self.get_dummy_inputs(torch_device) - output_lora = pipe(**inputs).images - self.assertTrue( - not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output" - ) - - inputs = self.get_dummy_inputs(torch_device) - output_lora_scale = pipe(**inputs, joint_attention_kwargs={"scale": 0.5}).images - self.assertTrue( - not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3), - "Lora + scale should change the output", - ) - - inputs = self.get_dummy_inputs(torch_device) - output_lora_0_scale = pipe(**inputs, joint_attention_kwargs={"scale": 0.0}).images - self.assertTrue( - np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3), - "Lora + 0 scale should lead to same result as no LoRA", - ) - - def test_simple_inference_with_clip_encoders_lora_and_scale(self): - components = self.get_dummy_components() - transformer_lora_config = self.get_lora_config_for_transformer() - text_encoder_config = self.get_lora_config_for_text_encoders() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - - 
inputs = self.get_dummy_inputs(torch_device) - output_no_lora = pipe(**inputs).images - - pipe.transformer.add_adapter(transformer_lora_config) - pipe.text_encoder.add_adapter(text_encoder_config) - pipe.text_encoder_2.add_adapter(text_encoder_config) - self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer") - self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text_encoder_one") - self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text_encoder_two") - - inputs = self.get_dummy_inputs(torch_device) - output_lora = pipe(**inputs).images - self.assertTrue( - not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output" - ) - - inputs = self.get_dummy_inputs(torch_device) - output_lora_scale = pipe(**inputs, joint_attention_kwargs={"scale": 0.5}).images - self.assertTrue( - not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3), - "Lora + scale should change the output", - ) - - inputs = self.get_dummy_inputs(torch_device) - output_lora_0_scale = pipe(**inputs, joint_attention_kwargs={"scale": 0.0}).images - self.assertTrue( - np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3), - "Lora + 0 scale should lead to same result as no LoRA", - ) - - def test_simple_inference_with_transformer_fused(self): - components = self.get_dummy_components() - transformer_lora_config = self.get_lora_config_for_transformer() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - - inputs = self.get_dummy_inputs(torch_device) - output_no_lora = pipe(**inputs).images - - pipe.transformer.add_adapter(transformer_lora_config) - self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer") - - pipe.fuse_lora() - # Fusing should still keep the LoRA layers - self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer") - - inputs = self.get_dummy_inputs(torch_device) - ouput_fused = pipe(**inputs).images - self.assertFalse( - np.allclose(ouput_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output" - ) - - def test_simple_inference_with_transformer_fused_with_no_fusion(self): - components = self.get_dummy_components() - transformer_lora_config = self.get_lora_config_for_transformer() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - - inputs = self.get_dummy_inputs(torch_device) - output_no_lora = pipe(**inputs).images - - pipe.transformer.add_adapter(transformer_lora_config) - self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer") - inputs = self.get_dummy_inputs(torch_device) - ouput_lora = pipe(**inputs).images - - pipe.fuse_lora() - # Fusing should still keep the LoRA layers - self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer") - - inputs = self.get_dummy_inputs(torch_device) - ouput_fused = pipe(**inputs).images - self.assertFalse( - np.allclose(ouput_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output" - ) - self.assertTrue( - np.allclose(ouput_fused, ouput_lora, atol=1e-3, rtol=1e-3), - "Fused lora output should be changed when LoRA isn't fused but still effective.", - ) - - def test_simple_inference_with_transformer_fuse_unfuse(self): - 
components = self.get_dummy_components() - transformer_lora_config = self.get_lora_config_for_transformer() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - - inputs = self.get_dummy_inputs(torch_device) - output_no_lora = pipe(**inputs).images - - pipe.transformer.add_adapter(transformer_lora_config) - self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer") - - pipe.fuse_lora() - # Fusing should still keep the LoRA layers - self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer") - inputs = self.get_dummy_inputs(torch_device) - ouput_fused = pipe(**inputs).images - self.assertFalse( - np.allclose(ouput_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output" - ) - - pipe.unfuse_lora() - self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer") - inputs = self.get_dummy_inputs(torch_device) - output_unfused_lora = pipe(**inputs).images - self.assertTrue( - np.allclose(ouput_fused, output_unfused_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output" - ) + scheduler_cls = FlowMatchEulerDiscreteScheduler() + scheduler_kwargs = {} + transformer_kwargs = { + "sample_size": 32, + "patch_size": 1, + "in_channels": 4, + "num_layers": 1, + "attention_head_dim": 8, + "num_attention_heads": 4, + "caption_projection_dim": 32, + "joint_attention_dim": 32, + "pooled_projection_dim": 64, + "out_channels": 4, + } + vae_kwargs = { + "sample_size": 32, + "in_channels": 3, + "out_channels": 3, + "block_out_channels": (4,), + "layers_per_block": 1, + "latent_channels": 4, + "norm_num_groups": 1, + "use_quant_conv": False, + "use_post_quant_conv": False, + "shift_factor": 0.0609, + "scaling_factor": 1.5035, + } + has_three_text_encoders = True @require_torch_gpu def test_sd3_lora(self): diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 9a07727db931..938a46c2db10 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -19,12 +19,14 @@ import numpy as np import torch -from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer +from transformers import AutoTokenizer, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel from diffusers import ( AutoencoderKL, DDIMScheduler, + FlowMatchEulerDiscreteScheduler, LCMScheduler, + SD3Transformer2DModel, UNet2DConditionModel, ) from diffusers.utils.import_utils import is_peft_available @@ -71,28 +73,47 @@ class PeftLoraLoaderMixinTests: scheduler_cls = None scheduler_kwargs = None has_two_text_encoders = False + has_three_text_encoders = False unet_kwargs = None + transformer_kwargs = None vae_kwargs = None def get_dummy_components(self, scheduler_cls=None, use_dora=False): + if self.unet_kwargs and self.transformer_kwargs: + raise ValueError("Both `unet_kwargs` and `transformer_kwargs` cannot be specified.") + if self.has_two_text_encoders and self.has_three_text_encoders: + raise ValueError("Both `has_two_text_encoders` and `has_three_text_encoders` cannot be True.") + scheduler_cls = self.scheduler_cls if scheduler_cls is None else scheduler_cls rank = 4 torch.manual_seed(0) - unet = UNet2DConditionModel(**self.unet_kwargs) + if self.unet_kwargs is not None: + unet = UNet2DConditionModel(**self.unet_kwargs) + else: + transformer = SD3Transformer2DModel(**self.transformer_kwargs) scheduler = scheduler_cls(**self.scheduler_kwargs) torch.manual_seed(0) vae = 
AutoencoderKL(**self.vae_kwargs) - text_encoder = CLIPTextModel.from_pretrained("peft-internal-testing/tiny-clip-text-2") - tokenizer = CLIPTokenizer.from_pretrained("peft-internal-testing/tiny-clip-text-2") + if not self.has_three_text_encoders: + text_encoder = CLIPTextModel.from_pretrained("peft-internal-testing/tiny-clip-text-2") + tokenizer = CLIPTokenizer.from_pretrained("peft-internal-testing/tiny-clip-text-2") if self.has_two_text_encoders: text_encoder_2 = CLIPTextModelWithProjection.from_pretrained("peft-internal-testing/tiny-clip-text-2") tokenizer_2 = CLIPTokenizer.from_pretrained("peft-internal-testing/tiny-clip-text-2") + if self.has_three_text_encoders: + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + tokenizer_3 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + text_encoder = CLIPTextModelWithProjection.from_pretrained("hf-internal-testing/tiny-sd3-text_encoder") + text_encoder_2 = CLIPTextModelWithProjection.from_pretrained("hf-internal-testing/tiny-sd3-text_encoder-2") + text_encoder_3 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + text_lora_config = LoraConfig( r=rank, lora_alpha=rank, @@ -101,7 +122,7 @@ def get_dummy_components(self, scheduler_cls=None, use_dora=False): use_dora=use_dora, ) - unet_lora_config = LoraConfig( + denoiser_lora_config = LoraConfig( r=rank, lora_alpha=rank, target_modules=["to_q", "to_k", "to_v", "to_out.0"], @@ -109,18 +130,31 @@ def get_dummy_components(self, scheduler_cls=None, use_dora=False): use_dora=use_dora, ) - if self.has_two_text_encoders: - pipeline_components = { - "unet": unet, - "scheduler": scheduler, - "vae": vae, - "text_encoder": text_encoder, - "tokenizer": tokenizer, - "text_encoder_2": text_encoder_2, - "tokenizer_2": tokenizer_2, - "image_encoder": None, - "feature_extractor": None, - } + if self.has_two_text_encoders or self.has_three_text_encoders: + if self.unet_kwargs is not None: + pipeline_components = { + "unet": unet, + "scheduler": scheduler, + "vae": vae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "text_encoder_2": text_encoder_2, + "tokenizer_2": tokenizer_2, + "image_encoder": None, + "feature_extractor": None, + } + elif self.has_three_text_encoders and self.transformer_kwargs is not None: + pipeline_components = { + "transformer": transformer, + "scheduler": scheduler, + "vae": vae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "text_encoder_2": text_encoder_2, + "tokenizer_2": tokenizer_2, + "text_encoder_3": text_encoder_3, + "tokenizer_3": tokenizer_3, + } else: pipeline_components = { "unet": unet, @@ -133,7 +167,7 @@ def get_dummy_components(self, scheduler_cls=None, use_dora=False): "image_encoder": None, } - return pipeline_components, text_lora_config, unet_lora_config + return pipeline_components, text_lora_config, denoiser_lora_config def get_dummy_inputs(self, with_generator=True): batch_size = 1 @@ -170,7 +204,12 @@ def test_simple_inference(self): """ Tests a simple inference and makes sure it works as expected """ - for scheduler_cls in [DDIMScheduler, LCMScheduler]: + scheduler_classes = ( + [FlowMatchEulerDiscreteScheduler] + if self.has_three_text_encoders and self.transformer_kwargs + else [DDIMScheduler, LCMScheduler] + ) + for scheduler_cls in scheduler_classes: components, text_lora_config, _ = self.get_dummy_components(scheduler_cls) pipe = self.pipeline_class(**components) pipe = 
pipe.to(torch_device) @@ -178,14 +217,20 @@ def test_simple_inference(self): _, _, inputs = self.get_dummy_inputs() output_no_lora = pipe(**inputs).images - self.assertTrue(output_no_lora.shape == (1, 64, 64, 3)) + shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3) + self.assertTrue(output_no_lora.shape == shape_to_be_checked) def test_simple_inference_with_text_lora(self): """ Tests a simple inference with lora attached on the text encoder and makes sure it works as expected """ - for scheduler_cls in [DDIMScheduler, LCMScheduler]: + scheduler_classes = ( + [FlowMatchEulerDiscreteScheduler] + if self.has_three_text_encoders and self.transformer_kwargs + else [DDIMScheduler, LCMScheduler] + ) + for scheduler_cls in scheduler_classes: components, text_lora_config, _ = self.get_dummy_components(scheduler_cls) pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) @@ -193,12 +238,13 @@ def test_simple_inference_with_text_lora(self): _, _, inputs = self.get_dummy_inputs(with_generator=False) output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images - self.assertTrue(output_no_lora.shape == (1, 64, 64, 3)) + shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3) + self.assertTrue(output_no_lora.shape == shape_to_be_checked) pipe.text_encoder.add_adapter(text_lora_config) self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") - if self.has_two_text_encoders: + if self.has_two_text_encoders or self.has_three_text_encoders: pipe.text_encoder_2.add_adapter(text_lora_config) self.assertTrue( check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" @@ -214,7 +260,12 @@ def test_simple_inference_with_text_lora_and_scale(self): Tests a simple inference with lora attached on the text encoder + scale argument and makes sure it works as expected """ - for scheduler_cls in [DDIMScheduler, LCMScheduler]: + scheduler_classes = ( + [FlowMatchEulerDiscreteScheduler] + if self.has_three_text_encoders and self.transformer_kwargs + else [DDIMScheduler, LCMScheduler] + ) + for scheduler_cls in scheduler_classes: components, text_lora_config, _ = self.get_dummy_components(scheduler_cls) pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) @@ -222,12 +273,13 @@ def test_simple_inference_with_text_lora_and_scale(self): _, _, inputs = self.get_dummy_inputs(with_generator=False) output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images - self.assertTrue(output_no_lora.shape == (1, 64, 64, 3)) + shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3) + self.assertTrue(output_no_lora.shape == shape_to_be_checked) pipe.text_encoder.add_adapter(text_lora_config) self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") - if self.has_two_text_encoders: + if self.has_two_text_encoders or self.has_three_text_encoders: pipe.text_encoder_2.add_adapter(text_lora_config) self.assertTrue( check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" @@ -238,17 +290,27 @@ def test_simple_inference_with_text_lora_and_scale(self): not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output" ) - output_lora_scale = pipe( - **inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.5} - ).images + if self.unet_kwargs is not None: + output_lora_scale = pipe( 
+ **inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.5} + ).images + else: + output_lora_scale = pipe( + **inputs, generator=torch.manual_seed(0), joint_attention_kwargs={"scale": 0.5} + ).images self.assertTrue( not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3), "Lora + scale should change the output", ) - output_lora_0_scale = pipe( - **inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.0} - ).images + if self.unet_kwargs is not None: + output_lora_0_scale = pipe( + **inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.0} + ).images + else: + output_lora_0_scale = pipe( + **inputs, generator=torch.manual_seed(0), joint_attention_kwargs={"scale": 0.0} + ).images self.assertTrue( np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3), "Lora + 0 scale should lead to same result as no LoRA", @@ -259,7 +321,12 @@ def test_simple_inference_with_text_lora_fused(self): Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model and makes sure it works as expected """ - for scheduler_cls in [DDIMScheduler, LCMScheduler]: + scheduler_classes = ( + [FlowMatchEulerDiscreteScheduler] + if self.has_three_text_encoders and self.transformer_kwargs + else [DDIMScheduler, LCMScheduler] + ) + for scheduler_cls in scheduler_classes: components, text_lora_config, _ = self.get_dummy_components(scheduler_cls) pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) @@ -267,12 +334,13 @@ def test_simple_inference_with_text_lora_fused(self): _, _, inputs = self.get_dummy_inputs(with_generator=False) output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images - self.assertTrue(output_no_lora.shape == (1, 64, 64, 3)) + shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3) + self.assertTrue(output_no_lora.shape == shape_to_be_checked) pipe.text_encoder.add_adapter(text_lora_config) self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") - if self.has_two_text_encoders: + if self.has_two_text_encoders or self.has_three_text_encoders: pipe.text_encoder_2.add_adapter(text_lora_config) self.assertTrue( check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" @@ -282,7 +350,7 @@ def test_simple_inference_with_text_lora_fused(self): # Fusing should still keep the LoRA layers self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") - if self.has_two_text_encoders: + if self.has_two_text_encoders or self.has_three_text_encoders: self.assertTrue( check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" ) @@ -297,7 +365,12 @@ def test_simple_inference_with_text_lora_unloaded(self): Tests a simple inference with lora attached to text encoder, then unloads the lora weights and makes sure it works as expected """ - for scheduler_cls in [DDIMScheduler, LCMScheduler]: + scheduler_classes = ( + [FlowMatchEulerDiscreteScheduler] + if self.has_three_text_encoders and self.transformer_kwargs + else [DDIMScheduler, LCMScheduler] + ) + for scheduler_cls in scheduler_classes: components, text_lora_config, _ = self.get_dummy_components(scheduler_cls) pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) @@ -305,12 +378,13 @@ def test_simple_inference_with_text_lora_unloaded(self): _, _, inputs = self.get_dummy_inputs(with_generator=False) 
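The two branches above differ only in the keyword that carries the scale. A hypothetical helper, not part of this PR, makes that relationship explicit:

```python
def scale_kwargs(uses_unet: bool, scale: float) -> dict:
    # UNet pipelines read the LoRA scale from `cross_attention_kwargs`;
    # SD3's transformer pipelines read it from `joint_attention_kwargs`.
    key = "cross_attention_kwargs" if uses_unet else "joint_attention_kwargs"
    return {key: {"scale": scale}}

# e.g. inside a test:
#   output = pipe(**inputs, generator=torch.manual_seed(0),
#                 **scale_kwargs(self.unet_kwargs is not None, 0.5)).images
```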
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images - self.assertTrue(output_no_lora.shape == (1, 64, 64, 3)) + shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3) + self.assertTrue(output_no_lora.shape == shape_to_be_checked) pipe.text_encoder.add_adapter(text_lora_config) self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") - if self.has_two_text_encoders: + if self.has_two_text_encoders or self.has_three_text_encoders: pipe.text_encoder_2.add_adapter(text_lora_config) self.assertTrue( check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" @@ -322,7 +396,7 @@ def test_simple_inference_with_text_lora_unloaded(self): check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder" ) - if self.has_two_text_encoders: + if self.has_two_text_encoders or self.has_three_text_encoders: self.assertFalse( check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly unloaded in text encoder 2", @@ -338,7 +412,12 @@ def test_simple_inference_with_text_lora_save_load(self): """ Tests a simple usecase where users could use saving utilities for LoRA. """ - for scheduler_cls in [DDIMScheduler, LCMScheduler]: + scheduler_classes = ( + [FlowMatchEulerDiscreteScheduler] + if self.has_three_text_encoders and self.transformer_kwargs + else [DDIMScheduler, LCMScheduler] + ) + for scheduler_cls in scheduler_classes: components, text_lora_config, _ = self.get_dummy_components(scheduler_cls) pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) @@ -346,12 +425,13 @@ def test_simple_inference_with_text_lora_save_load(self): _, _, inputs = self.get_dummy_inputs(with_generator=False) output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images - self.assertTrue(output_no_lora.shape == (1, 64, 64, 3)) + shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3) + self.assertTrue(output_no_lora.shape == shape_to_be_checked) pipe.text_encoder.add_adapter(text_lora_config) self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") - if self.has_two_text_encoders: + if self.has_two_text_encoders or self.has_three_text_encoders: pipe.text_encoder_2.add_adapter(text_lora_config) self.assertTrue( check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" @@ -361,7 +441,7 @@ def test_simple_inference_with_text_lora_save_load(self): with tempfile.TemporaryDirectory() as tmpdirname: text_encoder_state_dict = get_peft_model_state_dict(pipe.text_encoder) - if self.has_two_text_encoders: + if self.has_two_text_encoders or self.has_three_text_encoders: text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2) self.pipeline_class.save_lora_weights( @@ -385,7 +465,7 @@ def test_simple_inference_with_text_lora_save_load(self): images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0)).images self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") - if self.has_two_text_encoders: + if self.has_two_text_encoders or self.has_three_text_encoders: self.assertTrue( check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" ) @@ -401,7 +481,12 @@ def test_simple_inference_with_partial_text_lora(self): with different ranks and some adapters removed and makes sure it works as expected """ - for scheduler_cls 
in [DDIMScheduler, LCMScheduler]: + scheduler_classes = ( + [FlowMatchEulerDiscreteScheduler] + if self.has_three_text_encoders and self.transformer_kwargs + else [DDIMScheduler, LCMScheduler] + ) + for scheduler_cls in scheduler_classes: components, _, _ = self.get_dummy_components(scheduler_cls) # Verify `LoraLoaderMixin.load_lora_into_text_encoder` handles different ranks per module (PR#8324). text_lora_config = LoraConfig( @@ -418,7 +503,8 @@ def test_simple_inference_with_partial_text_lora(self): _, _, inputs = self.get_dummy_inputs(with_generator=False) output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images - self.assertTrue(output_no_lora.shape == (1, 64, 64, 3)) + shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3) + self.assertTrue(output_no_lora.shape == shape_to_be_checked) pipe.text_encoder.add_adapter(text_lora_config) self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") @@ -430,7 +516,7 @@ def test_simple_inference_with_partial_text_lora(self): if "text_model.encoder.layers.4" not in module_name } - if self.has_two_text_encoders: + if self.has_two_text_encoders or self.has_three_text_encoders: pipe.text_encoder_2.add_adapter(text_lora_config) self.assertTrue( check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" @@ -462,7 +548,12 @@ def test_simple_inference_save_pretrained(self): """ Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained """ - for scheduler_cls in [DDIMScheduler, LCMScheduler]: + scheduler_classes = ( + [FlowMatchEulerDiscreteScheduler] + if self.has_three_text_encoders and self.transformer_kwargs + else [DDIMScheduler, LCMScheduler] + ) + for scheduler_cls in scheduler_classes: components, text_lora_config, _ = self.get_dummy_components(scheduler_cls) pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) @@ -470,12 +561,13 @@ def test_simple_inference_save_pretrained(self): _, _, inputs = self.get_dummy_inputs(with_generator=False) output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images - self.assertTrue(output_no_lora.shape == (1, 64, 64, 3)) + shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3) + self.assertTrue(output_no_lora.shape == shape_to_be_checked) pipe.text_encoder.add_adapter(text_lora_config) self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") - if self.has_two_text_encoders: + if self.has_two_text_encoders or self.has_three_text_encoders: pipe.text_encoder_2.add_adapter(text_lora_config) self.assertTrue( check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" @@ -494,7 +586,7 @@ def test_simple_inference_save_pretrained(self): "Lora not correctly set in text encoder", ) - if self.has_two_text_encoders: + if self.has_two_text_encoders or self.has_three_text_encoders: self.assertTrue( check_if_lora_correctly_set(pipe_from_pretrained.text_encoder_2), "Lora not correctly set in text encoder 2", @@ -507,27 +599,42 @@ def test_simple_inference_save_pretrained(self): "Loading from saved checkpoints should give same results.", ) - def test_simple_inference_with_text_unet_lora_save_load(self): + def test_simple_inference_with_text_denoiser_lora_save_load(self): """ Tests a simple usecase where users could use saving utilities for LoRA for Unet + text encoder """ - for scheduler_cls in [DDIMScheduler, LCMScheduler]: 
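The `scheduler_classes` ternary introduced here is repeated verbatim in every test that follows. A hypothetical consolidation, not in this PR, would state the choice once as a property:

```python
from diffusers import DDIMScheduler, FlowMatchEulerDiscreteScheduler, LCMScheduler

class PeftLoraLoaderMixinTests:
    has_three_text_encoders = False
    transformer_kwargs = None

    @property
    def scheduler_classes(self):
        # SD3 (flow matching, transformer denoiser) is exercised only with its
        # native scheduler; UNet pipelines keep the DDIM/LCM pair.
        if self.has_three_text_encoders and self.transformer_kwargs:
            return [FlowMatchEulerDiscreteScheduler]
        return [DDIMScheduler, LCMScheduler]
```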
-            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+        scheduler_classes = (
+            [FlowMatchEulerDiscreteScheduler]
+            if self.has_three_text_encoders and self.transformer_kwargs
+            else [DDIMScheduler, LCMScheduler]
+        )
+        for scheduler_cls in scheduler_classes:
+            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(torch_device)
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

             output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
-            self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
+            shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3)
+            self.assertTrue(output_no_lora.shape == shape_to_be_checked)

             pipe.text_encoder.add_adapter(text_lora_config)
-            pipe.unet.add_adapter(unet_lora_config)
+            if self.unet_kwargs is not None:
+                pipe.unet.add_adapter(denoiser_lora_config)
+            else:
+                pipe.transformer.add_adapter(denoiser_lora_config)

             self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
-            self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
+            denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
+            self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")

-            if self.has_two_text_encoders:
+            if self.has_two_text_encoders or self.has_three_text_encoders:
                 pipe.text_encoder_2.add_adapter(text_lora_config)
                 self.assertTrue(
                     check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
@@ -537,22 +644,36 @@ def test_simple_inference_with_text_unet_lora_save_load(self):
             with tempfile.TemporaryDirectory() as tmpdirname:
                 text_encoder_state_dict = get_peft_model_state_dict(pipe.text_encoder)
-                unet_state_dict = get_peft_model_state_dict(pipe.unet)
-                if self.has_two_text_encoders:
+
+                if self.unet_kwargs is not None:
+                    denoiser_state_dict = get_peft_model_state_dict(pipe.unet)
+                else:
+                    denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
+
+                if self.has_two_text_encoders or self.has_three_text_encoders:
                     text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2)

-                    self.pipeline_class.save_lora_weights(
-                        save_directory=tmpdirname,
-                        text_encoder_lora_layers=text_encoder_state_dict,
-                        text_encoder_2_lora_layers=text_encoder_2_state_dict,
-                        unet_lora_layers=unet_state_dict,
-                        safe_serialization=False,
-                    )
+                    if self.unet_kwargs is not None:
+                        self.pipeline_class.save_lora_weights(
+                            save_directory=tmpdirname,
+                            text_encoder_lora_layers=text_encoder_state_dict,
+                            text_encoder_2_lora_layers=text_encoder_2_state_dict,
+                            unet_lora_layers=denoiser_state_dict,
+                            safe_serialization=False,
+                        )
+                    else:
+                        self.pipeline_class.save_lora_weights(
+                            save_directory=tmpdirname,
+                            text_encoder_lora_layers=text_encoder_state_dict,
+                            text_encoder_2_lora_layers=text_encoder_2_state_dict,
+                            transformer_lora_layers=denoiser_state_dict,
+                            safe_serialization=False,
+                        )
                 else:
                     self.pipeline_class.save_lora_weights(
                         save_directory=tmpdirname,
                         text_encoder_lora_layers=text_encoder_state_dict,
-                        unet_lora_layers=unet_state_dict,
+                        unet_lora_layers=denoiser_state_dict,
                         safe_serialization=False,
                     )

@@ -563,9 +684,10 @@ def
test_simple_inference_with_text_unet_lora_save_load(self): images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0)).images self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") - self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet") + denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer + self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser") - if self.has_two_text_encoders: + if self.has_two_text_encoders or self.has_three_text_encoders: self.assertTrue( check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" ) @@ -575,27 +697,37 @@ def test_simple_inference_with_text_unet_lora_save_load(self): "Loading from saved checkpoints should give same results.", ) - def test_simple_inference_with_text_unet_lora_and_scale(self): + def test_simple_inference_with_text_denoiser_lora_and_scale(self): """ Tests a simple inference with lora attached on the text encoder + Unet + scale argument and makes sure it works as expected """ - for scheduler_cls in [DDIMScheduler, LCMScheduler]: - components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls) + scheduler_classes = ( + [FlowMatchEulerDiscreteScheduler] + if self.has_three_text_encoders and self.transformer_kwargs + else [DDIMScheduler, LCMScheduler] + ) + for scheduler_cls in scheduler_classes: + components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images - self.assertTrue(output_no_lora.shape == (1, 64, 64, 3)) + shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3) + self.assertTrue(output_no_lora.shape == shape_to_be_checked) pipe.text_encoder.add_adapter(text_lora_config) - pipe.unet.add_adapter(unet_lora_config) + if self.unet_kwargs is not None: + pipe.unet.add_adapter(denoiser_lora_config) + else: + pipe.transformer.add_adapter(denoiser_lora_config) self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") - self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet") + denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer + self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser") - if self.has_two_text_encoders: + if self.has_two_text_encoders or self.has_three_text_encoders: pipe.text_encoder_2.add_adapter(text_lora_config) self.assertTrue( check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" @@ -606,17 +738,27 @@ def test_simple_inference_with_text_unet_lora_and_scale(self): not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output" ) - output_lora_scale = pipe( - **inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.5} - ).images + if self.unet_kwargs is not None: + output_lora_scale = pipe( + **inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.5} + ).images + else: + output_lora_scale = pipe( + **inputs, generator=torch.manual_seed(0), joint_attention_kwargs={"scale": 0.5} + ).images self.assertTrue( 
not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3), "Lora + scale should change the output", ) - output_lora_0_scale = pipe( - **inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.0} - ).images + if self.unet_kwargs is not None: + output_lora_0_scale = pipe( + **inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.0} + ).images + else: + output_lora_0_scale = pipe( + **inputs, generator=torch.manual_seed(0), joint_attention_kwargs={"scale": 0.0} + ).images self.assertTrue( np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3), "Lora + 0 scale should lead to same result as no LoRA", @@ -627,28 +769,38 @@ def test_simple_inference_with_text_unet_lora_and_scale(self): "The scaling parameter has not been correctly restored!", ) - def test_simple_inference_with_text_lora_unet_fused(self): + def test_simple_inference_with_text_lora_denoiser_fused(self): """ Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model and makes sure it works as expected - with unet """ - for scheduler_cls in [DDIMScheduler, LCMScheduler]: - components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls) + scheduler_classes = ( + [FlowMatchEulerDiscreteScheduler] + if self.has_three_text_encoders and self.transformer_kwargs + else [DDIMScheduler, LCMScheduler] + ) + for scheduler_cls in scheduler_classes: + components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images - self.assertTrue(output_no_lora.shape == (1, 64, 64, 3)) + shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3) + self.assertTrue(output_no_lora.shape == shape_to_be_checked) pipe.text_encoder.add_adapter(text_lora_config) - pipe.unet.add_adapter(unet_lora_config) + if self.unet_kwargs is not None: + pipe.unet.add_adapter(denoiser_lora_config) + else: + pipe.transformer.add_adapter(denoiser_lora_config) self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") - self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet") + denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer + self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser") - if self.has_two_text_encoders: + if self.has_two_text_encoders or self.has_three_text_encoders: pipe.text_encoder_2.add_adapter(text_lora_config) self.assertTrue( check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" @@ -657,9 +809,10 @@ def test_simple_inference_with_text_lora_unet_fused(self): pipe.fuse_lora() # Fusing should still keep the LoRA layers self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") - self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in unet") + denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer + self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser") - if self.has_two_text_encoders: + if self.has_two_text_encoders or self.has_three_text_encoders: self.assertTrue( 
                 check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
             )
@@ -669,27 +822,37 @@ def test_simple_inference_with_text_lora_unet_fused(self):
             np.allclose(ouput_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
         )
 
-    def test_simple_inference_with_text_unet_lora_unloaded(self):
+    def test_simple_inference_with_text_denoiser_lora_unloaded(self):
         """
         Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights
         and makes sure it works as expected
         """
-        for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+        scheduler_classes = (
+            [FlowMatchEulerDiscreteScheduler]
+            if self.has_three_text_encoders and self.transformer_kwargs
+            else [DDIMScheduler, LCMScheduler]
+        )
+        for scheduler_cls in scheduler_classes:
+            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(torch_device)
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
             output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
-            self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
+            shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3)
+            self.assertTrue(output_no_lora.shape == shape_to_be_checked)
 
             pipe.text_encoder.add_adapter(text_lora_config)
-            pipe.unet.add_adapter(unet_lora_config)
+            if self.unet_kwargs is not None:
+                pipe.unet.add_adapter(denoiser_lora_config)
+            else:
+                pipe.transformer.add_adapter(denoiser_lora_config)
             self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
-            self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
+            denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
+            self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
 
-            if self.has_two_text_encoders:
+            if self.has_two_text_encoders or self.has_three_text_encoders:
                 pipe.text_encoder_2.add_adapter(text_lora_config)
                 self.assertTrue(
                     check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
@@ -700,9 +863,12 @@ def test_simple_inference_with_text_unet_lora_unloaded(self):
             self.assertFalse(
                 check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder"
             )
-            self.assertFalse(check_if_lora_correctly_set(pipe.unet), "Lora not correctly unloaded in Unet")
+            denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
+            self.assertFalse(
+                check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly unloaded in denoiser"
+            )
 
-            if self.has_two_text_encoders:
+            if self.has_two_text_encoders or self.has_three_text_encoders:
                 self.assertFalse(
                     check_if_lora_correctly_set(pipe.text_encoder_2),
                     "Lora not correctly unloaded in text encoder 2",
@@ -714,25 +880,34 @@ def test_simple_inference_with_text_unet_lora_unloaded(self):
                 "Output after unloading LoRA should match the output without LoRA",
             )
 
-    def test_simple_inference_with_text_unet_lora_unfused(self):
+    def test_simple_inference_with_text_denoiser_lora_unfused(self):
         """
         Tests a simple inference with lora attached to text encoder and unet, then fuses and unfuses
         the lora weights and makes sure it works as expected
         """
-        for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+        scheduler_classes = (
+            [FlowMatchEulerDiscreteScheduler]
+            if self.has_three_text_encoders and self.transformer_kwargs
+            else [DDIMScheduler, LCMScheduler]
+        )
+        for scheduler_cls in scheduler_classes:
+            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(torch_device)
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
             pipe.text_encoder.add_adapter(text_lora_config)
-            pipe.unet.add_adapter(unet_lora_config)
+            if self.unet_kwargs is not None:
+                pipe.unet.add_adapter(denoiser_lora_config)
+            else:
+                pipe.transformer.add_adapter(denoiser_lora_config)
             self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
-            self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
+            denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
+            self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
 
-            if self.has_two_text_encoders:
+            if self.has_two_text_encoders or self.has_three_text_encoders:
                 pipe.text_encoder_2.add_adapter(text_lora_config)
                 self.assertTrue(
                     check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
@@ -747,9 +922,10 @@ def test_simple_inference_with_text_unet_lora_unfused(self):
             output_unfused_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
             # unfusing should not remove the LoRA layers
             self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers")
-            self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Unfuse should still keep LoRA layers")
+            denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
+            self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Unfuse should still keep LoRA layers")
 
-            if self.has_two_text_encoders:
+            if self.has_two_text_encoders or self.has_three_text_encoders:
                 self.assertTrue(
                     check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers"
                 )
@@ -760,13 +936,18 @@ def test_simple_inference_with_text_unet_lora_unfused(self):
                 "Fused and unfused outputs should match",
             )
 
-    def test_simple_inference_with_text_unet_multi_adapter(self):
+    def test_simple_inference_with_text_denoiser_multi_adapter(self):
         """
         Tests a simple inference with lora attached to text encoder and unet, attaches
         multiple adapters and sets them
         """
-        for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+        scheduler_classes = (
+            [FlowMatchEulerDiscreteScheduler]
+            if self.has_three_text_encoders and self.transformer_kwargs
+            else [DDIMScheduler, LCMScheduler]
+        )
+        for scheduler_cls in scheduler_classes:
+            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(torch_device)
             pipe.set_progress_bar_config(disable=None)
@@ -777,13 +958,20 @@ def test_simple_inference_with_text_unet_multi_adapter(self):
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
 
-            pipe.unet.add_adapter(unet_lora_config, "adapter-1")
-            pipe.unet.add_adapter(unet_lora_config, "adapter-2")
+            if self.unet_kwargs is not None:
+                pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
+            else:
+                pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
+            if self.unet_kwargs is not None:
+                pipe.unet.add_adapter(denoiser_lora_config, "adapter-2")
+            else:
+                pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2")
 
             self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
-            self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
+            denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
+            self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
 
-            if self.has_two_text_encoders:
+            if self.has_two_text_encoders or self.has_three_text_encoders:
                 pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                 pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
                 self.assertTrue(
@@ -826,13 +1014,21 @@ def test_simple_inference_with_text_unet_multi_adapter(self):
                 "output with no lora and output with lora disabled should give same results",
             )
 
-    def test_simple_inference_with_text_unet_block_scale(self):
+    def test_simple_inference_with_text_denoiser_block_scale(self):
         """
         Tests a simple inference with lora attached to text encoder and unet, attaches
         one adapter and sets different weights for different blocks (i.e. block lora)
         """
-        for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+        if self.pipeline_class.__name__ == "StableDiffusion3Pipeline":
+            return
+
+        scheduler_classes = (
+            [FlowMatchEulerDiscreteScheduler]
+            if self.has_three_text_encoders and self.transformer_kwargs
+            else [DDIMScheduler, LCMScheduler]
+        )
+        for scheduler_cls in scheduler_classes:
+            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(torch_device)
             pipe.set_progress_bar_config(disable=None)
@@ -841,12 +1037,16 @@ def test_simple_inference_with_text_unet_block_scale(self):
             output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
 
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
-            pipe.unet.add_adapter(unet_lora_config, "adapter-1")
+            if self.unet_kwargs is not None:
+                pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
+            else:
+                pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
 
             self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
-            self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
+            denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
+            self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
 
-            if self.has_two_text_encoders:
+            if self.has_two_text_encoders or self.has_three_text_encoders:
                 pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                 self.assertTrue(
                     check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
@@ -881,13 +1081,21 @@ def test_simple_inference_with_text_unet_block_scale(self):
                 "output with no lora and output with lora disabled should give same results",
             )
 
-    def test_simple_inference_with_text_unet_multi_adapter_block_lora(self):
+    def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
         """
         Tests a simple inference with lora attached to text encoder and unet, attaches
         multiple adapters and sets different weights for different blocks (i.e. block lora)
         """
-        for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+        if self.pipeline_class.__name__ == "StableDiffusion3Pipeline":
+            return
+
+        scheduler_classes = (
+            [FlowMatchEulerDiscreteScheduler]
+            if self.has_three_text_encoders and self.transformer_kwargs
+            else [DDIMScheduler, LCMScheduler]
+        )
+        for scheduler_cls in scheduler_classes:
+            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(torch_device)
             pipe.set_progress_bar_config(disable=None)
@@ -898,13 +1106,20 @@ def test_simple_inference_with_text_unet_multi_adapter_block_lora(self):
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
 
-            pipe.unet.add_adapter(unet_lora_config, "adapter-1")
-            pipe.unet.add_adapter(unet_lora_config, "adapter-2")
+            if self.unet_kwargs is not None:
+                pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
+            else:
+                pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
+            if self.unet_kwargs is not None:
+                pipe.unet.add_adapter(denoiser_lora_config, "adapter-2")
+            else:
+                pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2")
 
             self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
-            self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
+            denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
+            self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
 
-            if self.has_two_text_encoders:
+            if self.has_two_text_encoders or self.has_three_text_encoders:
                 pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                 pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
                 self.assertTrue(
@@ -953,8 +1168,10 @@ def test_simple_inference_with_text_unet_multi_adapter_block_lora(self):
             with self.assertRaises(ValueError):
                 pipe.set_adapters(["adapter-1", "adapter-2"], [scales_1])
 
-    def test_simple_inference_with_text_unet_block_scale_for_all_dict_options(self):
+    def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
         """Tests that any valid combination of lora block scales can be used in pipe.set_adapters"""
+        if self.pipeline_class.__name__ == "StableDiffusion3Pipeline":
+            return
 
         def updown_options(blocks_with_tf, layers_per_block, value):
             """
@@ -1019,16 +1236,19 @@ def all_possible_dict_opts(unet, value):
 
             return opts
 
-        components, text_lora_config, unet_lora_config = self.get_dummy_components(self.scheduler_cls)
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(self.scheduler_cls)
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
         pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
-        pipe.unet.add_adapter(unet_lora_config, "adapter-1")
+        if self.unet_kwargs is not None:
+            pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
+        else:
+            pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
 
-        if self.has_two_text_encoders:
+        if self.has_two_text_encoders or self.has_three_text_encoders:
             pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
 
         for scale_dict in all_possible_dict_opts(pipe.unet, value=1234):
@@ -1038,13 +1258,18 @@ def all_possible_dict_opts(unet, value):
             pipe.set_adapters("adapter-1", scale_dict)  # test will fail if this line throws an error
 
-    def test_simple_inference_with_text_unet_multi_adapter_delete_adapter(self):
+    def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self):
         """
         Tests a simple inference with lora attached to text encoder and unet, attaches
         multiple adapters and sets/deletes them
         """
-        for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+        scheduler_classes = (
+            [FlowMatchEulerDiscreteScheduler]
+            if self.has_three_text_encoders and self.transformer_kwargs
+            else [DDIMScheduler, LCMScheduler]
+        )
+        for scheduler_cls in scheduler_classes:
+            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(torch_device)
             pipe.set_progress_bar_config(disable=None)
@@ -1055,13 +1280,20 @@ def test_simple_inference_with_text_unet_multi_adapter_delete_adapter(self):
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
 
-            pipe.unet.add_adapter(unet_lora_config, "adapter-1")
-            pipe.unet.add_adapter(unet_lora_config, "adapter-2")
+            if self.unet_kwargs is not None:
+                pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
+            else:
+                pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
+            if self.unet_kwargs is not None:
+                pipe.unet.add_adapter(denoiser_lora_config, "adapter-2")
+            else:
+                pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2")
 
             self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
-            self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
+            denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
+            self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
 
-            if self.has_two_text_encoders:
+            if self.has_two_text_encoders or self.has_three_text_encoders:
                 pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                 pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
                 self.assertTrue(
@@ -1113,8 +1345,14 @@ def test_simple_inference_with_text_unet_multi_adapter_delete_adapter(self):
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
 
-            pipe.unet.add_adapter(unet_lora_config, "adapter-1")
-            pipe.unet.add_adapter(unet_lora_config, "adapter-2")
+            if self.unet_kwargs is not None:
+                pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
+            else:
+                pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
+            if self.unet_kwargs is not None:
+                pipe.unet.add_adapter(denoiser_lora_config, "adapter-2")
+            else:
+                pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2")
 
             pipe.set_adapters(["adapter-1", "adapter-2"])
             pipe.delete_adapters(["adapter-1", "adapter-2"])
@@ -1126,13 +1364,18 @@ def test_simple_inference_with_text_unet_multi_adapter_delete_adapter(self):
                 "output with no lora and output with lora disabled should give same results",
             )
 
-    def test_simple_inference_with_text_unet_multi_adapter_weighted(self):
+    def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self):
         """
         Tests a simple inference with lora attached to text encoder and unet, attaches
         multiple adapters and sets them
         """
-        for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+        scheduler_classes = (
+            [FlowMatchEulerDiscreteScheduler]
+            if self.has_three_text_encoders and self.transformer_kwargs
+            else [DDIMScheduler, LCMScheduler]
+        )
+        for scheduler_cls in scheduler_classes:
+            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(torch_device)
             pipe.set_progress_bar_config(disable=None)
@@ -1143,13 +1386,20 @@ def test_simple_inference_with_text_unet_multi_adapter_weighted(self):
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
 
-            pipe.unet.add_adapter(unet_lora_config, "adapter-1")
-            pipe.unet.add_adapter(unet_lora_config, "adapter-2")
+            if self.unet_kwargs is not None:
+                pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
+            else:
+                pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
+            if self.unet_kwargs is not None:
+                pipe.unet.add_adapter(denoiser_lora_config, "adapter-2")
+            else:
+                pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2")
 
             self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
-            self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
+            denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
+            self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
 
-            if self.has_two_text_encoders:
+            if self.has_two_text_encoders or self.has_three_text_encoders:
                 pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                 pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
                 self.assertTrue(
@@ -1202,8 +1452,13 @@ def test_simple_inference_with_text_unet_multi_adapter_weighted(self):
 
     @skip_mps
     def test_lora_fuse_nan(self):
-        for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+        scheduler_classes = (
+            [FlowMatchEulerDiscreteScheduler]
+            if self.has_three_text_encoders and self.transformer_kwargs
+            else [DDIMScheduler, LCMScheduler]
+        )
+        for scheduler_cls in scheduler_classes:
+            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(torch_device)
             pipe.set_progress_bar_config(disable=None)
@@ -1211,16 +1466,23 @@ def test_lora_fuse_nan(self):
 
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
 
-            pipe.unet.add_adapter(unet_lora_config, "adapter-1")
+            if self.unet_kwargs is not None:
+                pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
+            else:
+                pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
 
             self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
-            self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
+            denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
+            self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
 
             # corrupt one LoRA weight with `inf` values
             with torch.no_grad():
-                pipe.unet.mid_block.attentions[0].transformer_blocks[0].attn1.to_q.lora_A["adapter-1"].weight += float(
-                    "inf"
-                )
+                if self.unet_kwargs:
+                    pipe.unet.mid_block.attentions[0].transformer_blocks[0].attn1.to_q.lora_A[
+                        "adapter-1"
+                    ].weight += float("inf")
+                else:
+                    pipe.transformer.transformer_blocks[0].attn.to_q.lora_A["adapter-1"].weight += float("inf")
 
             # with `safe_fusing=True` we should see an error
             with self.assertRaises(ValueError):
@@ -1238,21 +1500,32 @@ def test_get_adapters(self):
         Tests a simple use case where we attach multiple adapters and check if the results
         are the expected results
         """
-        for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+        scheduler_classes = (
+            [FlowMatchEulerDiscreteScheduler]
+            if self.has_three_text_encoders and self.transformer_kwargs
+            else [DDIMScheduler, LCMScheduler]
+        )
+        for scheduler_cls in scheduler_classes:
+            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(torch_device)
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
-            pipe.unet.add_adapter(unet_lora_config, "adapter-1")
+            if self.unet_kwargs is not None:
+                pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
+            else:
+                pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
 
             adapter_names = pipe.get_active_adapters()
             self.assertListEqual(adapter_names, ["adapter-1"])
 
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
-            pipe.unet.add_adapter(unet_lora_config, "adapter-2")
+            if self.unet_kwargs is not None:
+                pipe.unet.add_adapter(denoiser_lora_config, "adapter-2")
+            else:
+                pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2")
 
             adapter_names = pipe.get_active_adapters()
             self.assertListEqual(adapter_names, ["adapter-2"])
@@ -1265,65 +1538,108 @@ def test_get_list_adapters(self):
         """
         Tests a simple use case where we attach multiple adapters and check if the results
         are the expected results
         """
-        for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+        scheduler_classes = (
+            [FlowMatchEulerDiscreteScheduler]
+            if self.has_three_text_encoders and self.transformer_kwargs
+            else [DDIMScheduler, LCMScheduler]
+        )
+        for scheduler_cls in scheduler_classes:
+            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(torch_device)
             pipe.set_progress_bar_config(disable=None)
 
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
-            pipe.unet.add_adapter(unet_lora_config, "adapter-1")
+            if self.unet_kwargs is not None:
+                pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
+            else:
+                pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
 
             adapter_names = pipe.get_list_adapters()
-            self.assertDictEqual(adapter_names, {"text_encoder": ["adapter-1"], "unet": ["adapter-1"]})
+            dicts_to_be_checked = {"text_encoder": ["adapter-1"]}
+            if self.unet_kwargs is not None:
+                dicts_to_be_checked.update({"unet": ["adapter-1"]})
+            else:
+                dicts_to_be_checked.update({"transformer": ["adapter-1"]})
+            self.assertDictEqual(adapter_names, dicts_to_be_checked)
 
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
-            pipe.unet.add_adapter(unet_lora_config, "adapter-2")
+            if self.unet_kwargs is not None:
+                pipe.unet.add_adapter(denoiser_lora_config, "adapter-2")
+            else:
+                pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2")
 
             adapter_names = pipe.get_list_adapters()
-            self.assertDictEqual(
-                adapter_names, {"text_encoder": ["adapter-1", "adapter-2"], "unet": ["adapter-1", "adapter-2"]}
-            )
+            dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]}
+            if self.unet_kwargs is not None:
+                dicts_to_be_checked.update({"unet": ["adapter-1", "adapter-2"]})
+            else:
+                dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2"]})
+            self.assertDictEqual(adapter_names, dicts_to_be_checked)
 
             pipe.set_adapters(["adapter-1", "adapter-2"])
+            dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]}
+            if self.unet_kwargs is not None:
+                dicts_to_be_checked.update({"unet": ["adapter-1", "adapter-2"]})
+            else:
+                dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2"]})
             self.assertDictEqual(
                 pipe.get_list_adapters(),
-                {"unet": ["adapter-1", "adapter-2"], "text_encoder": ["adapter-1", "adapter-2"]},
+                dicts_to_be_checked,
             )
 
-            pipe.unet.add_adapter(unet_lora_config, "adapter-3")
-            self.assertDictEqual(
-                pipe.get_list_adapters(),
-                {"unet": ["adapter-1", "adapter-2", "adapter-3"], "text_encoder": ["adapter-1", "adapter-2"]},
-            )
+            if self.unet_kwargs is not None:
+                pipe.unet.add_adapter(denoiser_lora_config, "adapter-3")
+            else:
+                pipe.transformer.add_adapter(denoiser_lora_config, "adapter-3")
+
+            dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]}
+            if self.unet_kwargs is not None:
+                dicts_to_be_checked.update({"unet": ["adapter-1", "adapter-2", "adapter-3"]})
+            else:
+                dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2", "adapter-3"]})
+            self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked)
 
     @require_peft_version_greater(peft_version="0.6.2")
-    def test_simple_inference_with_text_lora_unet_fused_multi(self):
+    def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
         """
         Tests a simple inference with lora attached to text encoder + fuses the lora weights into base
         model and makes sure it works as expected - with the multi-adapter case
         """
-        for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+        scheduler_classes = (
+            [FlowMatchEulerDiscreteScheduler]
+            if self.has_three_text_encoders and self.transformer_kwargs
+            else [DDIMScheduler, LCMScheduler]
+        )
+        for scheduler_cls in scheduler_classes:
+            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
             pipe = pipe.to(torch_device)
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
             output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
-            self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
+            shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3)
+            self.assertTrue(output_no_lora.shape == shape_to_be_checked)
 
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
-            pipe.unet.add_adapter(unet_lora_config, "adapter-1")
+            if self.unet_kwargs is not None:
+                pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
+            else:
+                pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
 
             # Attach a second adapter
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
-            pipe.unet.add_adapter(unet_lora_config, "adapter-2")
+            if self.unet_kwargs is not None:
+                pipe.unet.add_adapter(denoiser_lora_config, "adapter-2")
+            else:
+                pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2")
 
             self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
-            self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
+            denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
+            self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
 
-            if self.has_two_text_encoders:
+            if self.has_two_text_encoders or self.has_three_text_encoders:
                 pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                 pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
                 self.assertTrue(
@@ -1359,23 +1675,35 @@ def test_simple_inference_with_text_lora_unet_fused_multi(self):
 
     @require_peft_version_greater(peft_version="0.9.0")
     def test_simple_inference_with_dora(self):
-        for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls, use_dora=True)
+        scheduler_classes = (
+            [FlowMatchEulerDiscreteScheduler]
+            if self.has_three_text_encoders and self.transformer_kwargs
+            else [DDIMScheduler, LCMScheduler]
+        )
+        for scheduler_cls in scheduler_classes:
+            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(
+                scheduler_cls, use_dora=True
+            )
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(torch_device)
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
             output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
-            self.assertTrue(output_no_dora_lora.shape == (1, 64, 64, 3))
+            shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3)
+            self.assertTrue(output_no_dora_lora.shape == shape_to_be_checked)
 
             pipe.text_encoder.add_adapter(text_lora_config)
-            pipe.unet.add_adapter(unet_lora_config)
+            if self.unet_kwargs is not None:
+                pipe.unet.add_adapter(denoiser_lora_config)
+            else:
+                pipe.transformer.add_adapter(denoiser_lora_config)
 
             self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
-            self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
+            denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
+            self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
 
-            if self.has_two_text_encoders:
+            if self.has_two_text_encoders or self.has_three_text_encoders:
                 pipe.text_encoder_2.add_adapter(text_lora_config)
                 self.assertTrue(
                     check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
@@ -1389,25 +1717,34 @@ def test_simple_inference_with_dora(self):
             )
 
     @unittest.skip("This is failing for now - need to investigate")
-    def test_simple_inference_with_text_unet_lora_unfused_torch_compile(self):
+    def test_simple_inference_with_text_denoiser_lora_unfused_torch_compile(self):
         """
         Tests a simple inference with lora attached to text encoder and unet, then compiles the models
         with torch.compile and makes sure it works as expected
         """
-        for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+        scheduler_classes = (
+            [FlowMatchEulerDiscreteScheduler]
+            if self.has_three_text_encoders and self.transformer_kwargs
+            else [DDIMScheduler, LCMScheduler]
+        )
+        for scheduler_cls in scheduler_classes:
+            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(torch_device)
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
             pipe.text_encoder.add_adapter(text_lora_config)
-            pipe.unet.add_adapter(unet_lora_config)
+            if self.unet_kwargs is not None:
+                pipe.unet.add_adapter(denoiser_lora_config)
+            else:
+                pipe.transformer.add_adapter(denoiser_lora_config)
 
             self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
-            self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
+            denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
+            self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
 
-            if self.has_two_text_encoders:
+            if self.has_two_text_encoders or self.has_three_text_encoders:
                 pipe.text_encoder_2.add_adapter(text_lora_config)
                 self.assertTrue(
                     check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
@@ -1416,19 +1753,27 @@ def test_simple_inference_with_text_unet_lora_unfused_torch_compile(self):
             pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
             pipe.text_encoder = torch.compile(pipe.text_encoder, mode="reduce-overhead", fullgraph=True)
 
-            if self.has_two_text_encoders:
+            if self.has_two_text_encoders or self.has_three_text_encoders:
                 pipe.text_encoder_2 = torch.compile(pipe.text_encoder_2, mode="reduce-overhead", fullgraph=True)
 
             # Just makes sure it works.
             _ = pipe(**inputs, generator=torch.manual_seed(0)).images
 
     def test_modify_padding_mode(self):
+        if self.pipeline_class.__name__ == "StableDiffusion3Pipeline":
+            return
+
         def set_pad_mode(network, mode="circular"):
             for _, module in network.named_modules():
                 if isinstance(module, torch.nn.Conv2d):
                     module.padding_mode = mode
 
-        for scheduler_cls in [DDIMScheduler, LCMScheduler]:
+        scheduler_classes = (
+            [FlowMatchEulerDiscreteScheduler]
+            if self.has_three_text_encoders and self.transformer_kwargs
+            else [DDIMScheduler, LCMScheduler]
+        )
+        for scheduler_cls in scheduler_classes:
             components, _, _ = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(torch_device)
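The hunks above apply the same two idioms to every test: the scheduler classes to exercise are chosen by whether the pipeline under test is an SD3-style pipeline (three text encoders plus a transformer denoiser), and every adapter operation is dispatched to `pipe.unet` or `pipe.transformer` depending on `unet_kwargs`. The following standalone sketch restates that logic for readers skimming the patch; it assumes the mixin attributes (`unet_kwargs`, `transformer_kwargs`, `has_three_text_encoders`) behave as in these tests, and the helper names are illustrative rather than part of the patch.

from diffusers import DDIMScheduler, FlowMatchEulerDiscreteScheduler, LCMScheduler


def scheduler_classes_to_test(test_mixin):
    # SD3-style pipelines are flow-matching models, so they are exercised only
    # with FlowMatchEulerDiscreteScheduler; UNet pipelines keep the original pair.
    if test_mixin.has_three_text_encoders and test_mixin.transformer_kwargs:
        return [FlowMatchEulerDiscreteScheduler]
    return [DDIMScheduler, LCMScheduler]


def get_denoiser(test_mixin, pipe):
    # `unet_kwargs` doubles as the discriminator: it is None when the pipeline's
    # denoiser is a transformer (e.g. SD3Transformer2DModel) instead of a UNet.
    return pipe.unet if test_mixin.unet_kwargs is not None else pipe.transformer

A helper like `get_denoiser` would also collapse the repeated if/else blocks in the hunks above into single calls such as `get_denoiser(self, pipe).add_adapter(denoiser_lora_config, "adapter-1")`; the patch keeps the branches inline instead.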