
Commit 2e8d18e

yiyixuxu, asomoza, and patrickvonplaten authored
[IP-Adapter] Support multiple IP-Adapters (#6573)
--------- Co-authored-by: yiyixuxu <yixu310@gmail.com> Co-authored-by: Alvaro Somoza <[email protected]> Co-authored-by: Patrick von Platen <[email protected]>
1 parent 03373de commit 2e8d18e

25 files changed: +895 −235 lines

docs/source/en/using-diffusers/loading_adapters.md

Lines changed: 61 additions & 12 deletions
@@ -506,22 +506,11 @@ import torch
 from diffusers import StableDiffusionPipeline, DDIMScheduler
 from diffusers.utils import load_image
 
-noise_scheduler = DDIMScheduler(
-    num_train_timesteps=1000,
-    beta_start=0.00085,
-    beta_end=0.012,
-    beta_schedule="scaled_linear",
-    clip_sample=False,
-    set_alpha_to_one=False,
-    steps_offset=1
-)
-
 pipeline = StableDiffusionPipeline.from_pretrained(
     "runwayml/stable-diffusion-v1-5",
     torch_dtype=torch.float16,
-    scheduler=noise_scheduler,
 ).to("cuda")
-
+pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
 pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-full-face_sd15.bin")
 
 pipeline.set_ip_adapter_scale(0.7)
@@ -550,6 +539,66 @@ image = pipeline(
 </div>
 </div>
 
+
+You can load multiple IP-Adapter models and use multiple reference images at the same time. In this example we use the IP-Adapter-Plus face model to create a consistent character and also use the IP-Adapter-Plus model along with 10 images to create a coherent style in the image we generate.
+
+```python
+import torch
+from diffusers import AutoPipelineForText2Image, DDIMScheduler
+from transformers import CLIPVisionModelWithProjection
+from diffusers.utils import load_image
+
+image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+    "h94/IP-Adapter",
+    subfolder="models/image_encoder",
+    torch_dtype=torch.float16,
+)
+
+pipeline = AutoPipelineForText2Image.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0",
+    torch_dtype=torch.float16,
+    image_encoder=image_encoder,
+)
+pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
+pipeline.load_ip_adapter(
+    "h94/IP-Adapter",
+    subfolder="sdxl_models",
+    weight_name=["ip-adapter-plus_sdxl_vit-h.safetensors", "ip-adapter-plus-face_sdxl_vit-h.safetensors"]
+)
+pipeline.set_ip_adapter_scale([0.7, 0.3])
+pipeline.enable_model_cpu_offload()
+
+face_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/women_input.png")
+style_folder = "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/style_ziggy"
+style_images = [load_image(f"{style_folder}/img{i}.png") for i in range(10)]
+
+generator = torch.Generator(device="cpu").manual_seed(0)
+
+image = pipeline(
+    prompt="wonderwoman",
+    ip_adapter_image=[style_images, face_image],
+    negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
+    num_inference_steps=50, num_images_per_prompt=1,
+    generator=generator,
+).images[0]
+```
+<div class="flex justify-center">
+    <img src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_style_grid.png" />
+    <figcaption class="mt-2 text-center text-sm text-gray-500">style input image</figcaption>
+</div>
+
+<div class="flex flex-row gap-4">
+  <div class="flex-1">
+    <img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/women_input.png"/>
+    <figcaption class="mt-2 text-center text-sm text-gray-500">face input image</figcaption>
+  </div>
+  <div class="flex-1">
+    <img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_multi_out.png"/>
+    <figcaption class="mt-2 text-center text-sm text-gray-500">output image</figcaption>
+  </div>
+</div>
+
+
 ### LCM-Lora
 
 You can use IP-Adapter with LCM-Lora to achieve "instant fine-tune" with custom images. Note that you need to load IP-Adapter weights before loading the LCM-Lora weights.
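The load-order note in that last context line can be made concrete with a minimal sketch; the LCM-LoRA repo `latent-consistency/lcm-lora-sdv1-5` and the `ip-adapter_sd15.bin` weight file are illustrative choices here, not part of this commit:

```python
import torch
from diffusers import LCMScheduler, StableDiffusionPipeline

pipeline = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# 1. Load the IP-Adapter weights first ...
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
# 2. ... then load the LCM-LoRA weights and switch to the LCM scheduler.
pipeline.load_lora_weights("latent-consistency/lcm-lora-sdv1-5")
pipeline.scheduler = LCMScheduler.from_config(pipeline.scheduler.config)
```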

src/diffusers/loaders/ip_adapter.py

Lines changed: 81 additions & 53 deletions
@@ -11,8 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 from pathlib import Path
-from typing import Dict, Union
+from typing import Dict, List, Union
 
 import torch
 from huggingface_hub.utils import validate_hf_hub_args
@@ -45,9 +46,9 @@ class IPAdapterMixin:
     @validate_hf_hub_args
     def load_ip_adapter(
         self,
-        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
-        subfolder: str,
-        weight_name: str,
+        pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
+        subfolder: Union[str, List[str]],
+        weight_name: Union[str, List[str]],
         **kwargs,
     ):
         """
@@ -87,6 +88,26 @@ def load_ip_adapter(
                 The subfolder location of a model file within a larger model repository on the Hub or locally.
         """
 
+        # handle the list inputs for multiple IP Adapters
+        if not isinstance(weight_name, list):
+            weight_name = [weight_name]
+
+        if not isinstance(pretrained_model_name_or_path_or_dict, list):
+            pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict]
+        if len(pretrained_model_name_or_path_or_dict) == 1:
+            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict * len(weight_name)
+
+        if not isinstance(subfolder, list):
+            subfolder = [subfolder]
+        if len(subfolder) == 1:
+            subfolder = subfolder * len(weight_name)
+
+        if len(weight_name) != len(pretrained_model_name_or_path_or_dict):
+            raise ValueError("`weight_name` and `pretrained_model_name_or_path_or_dict` must have the same length.")
+
+        if len(weight_name) != len(subfolder):
+            raise ValueError("`weight_name` and `subfolder` must have the same length.")
+
         # Load the main state dict first.
         cache_dir = kwargs.pop("cache_dir", None)
         force_download = kwargs.pop("force_download", False)
@@ -100,61 +121,68 @@ def load_ip_adapter(
             "file_type": "attn_procs_weights",
             "framework": "pytorch",
         }
-
-        if not isinstance(pretrained_model_name_or_path_or_dict, dict):
-            model_file = _get_model_file(
-                pretrained_model_name_or_path_or_dict,
-                weights_name=weight_name,
-                cache_dir=cache_dir,
-                force_download=force_download,
-                resume_download=resume_download,
-                proxies=proxies,
-                local_files_only=local_files_only,
-                token=token,
-                revision=revision,
-                subfolder=subfolder,
-                user_agent=user_agent,
-            )
-            if weight_name.endswith(".safetensors"):
-                state_dict = {"image_proj": {}, "ip_adapter": {}}
-                with safe_open(model_file, framework="pt", device="cpu") as f:
-                    for key in f.keys():
-                        if key.startswith("image_proj."):
-                            state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
-                        elif key.startswith("ip_adapter."):
-                            state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
-            else:
-                state_dict = torch.load(model_file, map_location="cpu")
-        else:
-            state_dict = pretrained_model_name_or_path_or_dict
-
-        keys = list(state_dict.keys())
-        if keys != ["image_proj", "ip_adapter"]:
-            raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")
-
-        # load CLIP image encoder here if it has not been registered to the pipeline yet
-        if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
+        state_dicts = []
+        for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
+            pretrained_model_name_or_path_or_dict, weight_name, subfolder
+        ):
             if not isinstance(pretrained_model_name_or_path_or_dict, dict):
-                logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
-                image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+                model_file = _get_model_file(
                     pretrained_model_name_or_path_or_dict,
-                    subfolder=Path(subfolder, "image_encoder").as_posix(),
-                ).to(self.device, dtype=self.dtype)
-                self.image_encoder = image_encoder
-                self.register_to_config(image_encoder=["transformers", "CLIPVisionModelWithProjection"])
+                    weights_name=weight_name,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    resume_download=resume_download,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    token=token,
+                    revision=revision,
+                    subfolder=subfolder,
+                    user_agent=user_agent,
+                )
+                if weight_name.endswith(".safetensors"):
+                    state_dict = {"image_proj": {}, "ip_adapter": {}}
+                    with safe_open(model_file, framework="pt", device="cpu") as f:
+                        for key in f.keys():
+                            if key.startswith("image_proj."):
+                                state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
+                            elif key.startswith("ip_adapter."):
+                                state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
+                else:
+                    state_dict = torch.load(model_file, map_location="cpu")
             else:
-                raise ValueError("`image_encoder` cannot be None when using IP Adapters.")
-
-        # create feature extractor if it has not been registered to the pipeline yet
-        if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
-            self.feature_extractor = CLIPImageProcessor()
-            self.register_to_config(feature_extractor=["transformers", "CLIPImageProcessor"])
-
-        # load ip-adapter into unet
+                state_dict = pretrained_model_name_or_path_or_dict
+
+            keys = list(state_dict.keys())
+            if keys != ["image_proj", "ip_adapter"]:
+                raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")
+
+            state_dicts.append(state_dict)
+
+        # load CLIP image encoder here if it has not been registered to the pipeline yet
+        if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
+            if not isinstance(pretrained_model_name_or_path_or_dict, dict):
+                logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
+                image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+                    pretrained_model_name_or_path_or_dict,
+                    subfolder=Path(subfolder, "image_encoder").as_posix(),
+                ).to(self.device, dtype=self.dtype)
+                self.image_encoder = image_encoder
+                self.register_to_config(image_encoder=["transformers", "CLIPVisionModelWithProjection"])
+            else:
+                raise ValueError("`image_encoder` cannot be None when using IP Adapters.")
+
+        # create feature extractor if it has not been registered to the pipeline yet
+        if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
+            feature_extractor = CLIPImageProcessor()
+            self.register_modules(feature_extractor=feature_extractor)
+
+        # load ip-adapter into unet
         unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
-        unet._load_ip_adapter_weights(state_dict)
+        unet._load_ip_adapter_weights(state_dicts)
 
     def set_ip_adapter_scale(self, scale):
+        if not isinstance(scale, list):
+            scale = [scale]
         unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
         for attn_processor in unet.attn_processors.values():
            if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
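The list handling added to `load_ip_adapter` above broadcasts a single checkpoint path or subfolder across however many `weight_name` entries are given, then requires all three lists to line up. A standalone sketch of that broadcasting rule; the helper name `_normalize_ip_adapter_args` is hypothetical and only mirrors the logic in the diff:

```python
from typing import Dict, List, Union


def _normalize_ip_adapter_args(
    pretrained: Union[str, List[str], Dict],
    subfolder: Union[str, List[str]],
    weight_name: Union[str, List[str]],
):
    """Sketch of the list broadcasting applied by load_ip_adapter."""
    if not isinstance(weight_name, list):
        weight_name = [weight_name]
    # A single repo (or state dict) is repeated once per weight file.
    if not isinstance(pretrained, list):
        pretrained = [pretrained]
    if len(pretrained) == 1:
        pretrained = pretrained * len(weight_name)
    # Same broadcasting for the subfolder.
    if not isinstance(subfolder, list):
        subfolder = [subfolder]
    if len(subfolder) == 1:
        subfolder = subfolder * len(weight_name)
    if len(weight_name) != len(pretrained) or len(weight_name) != len(subfolder):
        raise ValueError("All list arguments must have the same length.")
    return list(zip(pretrained, weight_name, subfolder))


# One repo and one subfolder are broadcast across two weight files:
print(
    _normalize_ip_adapter_args(
        "h94/IP-Adapter",
        "sdxl_models",
        ["ip-adapter-plus_sdxl_vit-h.safetensors", "ip-adapter-plus-face_sdxl_vit-h.safetensors"],
    )
)
```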

src/diffusers/loaders/unet.py

Lines changed: 39 additions & 20 deletions
@@ -25,7 +25,12 @@
 from huggingface_hub.utils import validate_hf_hub_args
 from torch import nn
 
-from ..models.embeddings import ImageProjection, IPAdapterFullImageProjection, IPAdapterPlusImageProjection
+from ..models.embeddings import (
+    ImageProjection,
+    IPAdapterFullImageProjection,
+    IPAdapterPlusImageProjection,
+    MultiIPAdapterImageProjection,
+)
 from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
 from ..utils import (
     USE_PEFT_BACKEND,
@@ -763,28 +768,14 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict):
         image_projection.load_state_dict(updated_state_dict)
         return image_projection
 
-    def _load_ip_adapter_weights(self, state_dict):
+    def _convert_ip_adapter_attn_to_diffusers(self, state_dicts):
         from ..models.attention_processor import (
             AttnProcessor,
             AttnProcessor2_0,
             IPAdapterAttnProcessor,
             IPAdapterAttnProcessor2_0,
         )
 
-        if "proj.weight" in state_dict["image_proj"]:
-            # IP-Adapter
-            num_image_text_embeds = 4
-        elif "proj.3.weight" in state_dict["image_proj"]:
-            # IP-Adapter Full Face
-            num_image_text_embeds = 257  # 256 CLIP tokens + 1 CLS token
-        else:
-            # IP-Adapter Plus
-            num_image_text_embeds = state_dict["image_proj"]["latents"].shape[1]
-
-        # Set encoder_hid_proj after loading ip_adapter weights,
-        # because `IPAdapterPlusImageProjection` also has `attn_processors`.
-        self.encoder_hid_proj = None
-
         # set ip-adapter cross-attention processors & load state_dict
         attn_procs = {}
         key_id = 1
@@ -798,6 +789,7 @@ def _load_ip_adapter_weights(self, state_dict):
             elif name.startswith("down_blocks"):
                 block_id = int(name[len("down_blocks.")])
                 hidden_size = self.config.block_out_channels[block_id]
+
             if cross_attention_dim is None or "motion_modules" in name:
                 attn_processor_class = (
                     AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor
@@ -807,6 +799,18 @@ def _load_ip_adapter_weights(self, state_dict):
                 attn_processor_class = (
                     IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor
                 )
+                num_image_text_embeds = []
+                for state_dict in state_dicts:
+                    if "proj.weight" in state_dict["image_proj"]:
+                        # IP-Adapter
+                        num_image_text_embeds += [4]
+                    elif "proj.3.weight" in state_dict["image_proj"]:
+                        # IP-Adapter Full Face
+                        num_image_text_embeds += [257]  # 256 CLIP tokens + 1 CLS token
+                    else:
+                        # IP-Adapter Plus
+                        num_image_text_embeds += [state_dict["image_proj"]["latents"].shape[1]]
+
                 attn_procs[name] = attn_processor_class(
                     hidden_size=hidden_size,
                     cross_attention_dim=cross_attention_dim,
@@ -815,16 +819,31 @@ def _load_ip_adapter_weights(self, state_dict):
                 ).to(dtype=self.dtype, device=self.device)
 
                 value_dict = {}
-                for k, w in attn_procs[name].state_dict().items():
-                    value_dict.update({f"{k}": state_dict["ip_adapter"][f"{key_id}.{k}"]})
+                for i, state_dict in enumerate(state_dicts):
+                    value_dict.update({f"to_k_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]})
+                    value_dict.update({f"to_v_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]})
 
                 attn_procs[name].load_state_dict(value_dict)
                 key_id += 2
 
+        return attn_procs
+
+    def _load_ip_adapter_weights(self, state_dicts):
+        if not isinstance(state_dicts, list):
+            state_dicts = [state_dicts]
+        # Set encoder_hid_proj after loading ip_adapter weights,
+        # because `IPAdapterPlusImageProjection` also has `attn_processors`.
+        self.encoder_hid_proj = None
+
+        attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dicts)
         self.set_attn_processor(attn_procs)
 
         # convert IP-Adapter Image Projection layers to diffusers
-        image_projection = self._convert_ip_adapter_image_proj_to_diffusers(state_dict["image_proj"])
+        image_projection_layers = []
+        for state_dict in state_dicts:
+            image_projection_layer = self._convert_ip_adapter_image_proj_to_diffusers(state_dict["image_proj"])
+            image_projection_layer.to(device=self.device, dtype=self.dtype)
+            image_projection_layers.append(image_projection_layer)
 
-        self.encoder_hid_proj = image_projection.to(device=self.device, dtype=self.dtype)
+        self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
         self.config.encoder_hid_dim_type = "ip_image_proj"