33 changes: 22 additions & 11 deletions src/diffusers/loaders/ip_adapter.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Dict, Union
from typing import Dict, Optional, Union

import torch
from huggingface_hub.utils import validate_hf_hub_args
@@ -34,6 +34,8 @@
from ..models.attention_processor import (
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
LoRAIPAdapterAttnProcessor,
LoRAIPAdapterAttnProcessor2_0,
)

logger = logging.get_logger(__name__)
@@ -46,8 +48,8 @@ class IPAdapterMixin:
def load_ip_adapter(
self,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
subfolder: str,
Contributor:

This is a breaking change, can we leave it as it was here?

Contributor Author:

Done as you suggested!

weight_name: str,
subfolder: Optional[str] = None,
Member:

We can pass this to the kwargs, right? I don't think there's a need to expose this specifically. WDYT?

Contributor Author:

Yes, we can pass it as a keyword argument. The other IP-Adapter models need the image encoder, stored in a subfolder of H94/IP-Adapter, while the FaceID model doesn't require an image encoder, so I made it Optional. I will remove it in the next update.
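A minimal sketch of the two call styles under discussion (repo ids, weight names, and the pipeline variable are illustrative, not taken from this PR):

# Standard IP-Adapter: the image encoder lives in a subfolder of the checkpoint repo.
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
# FaceID variant: no image encoder is shipped, so `subfolder` can stay None.
pipe.load_ip_adapter("h94/IP-Adapter-FaceID", weight_name="ip-adapter-faceid_sd15.bin")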

**kwargs,
):
"""
@@ -135,14 +137,15 @@ def load_ip_adapter(
# load CLIP image encoder here if it has not been registered to the pipeline yet
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
pretrained_model_name_or_path_or_dict,
subfolder=os.path.join(subfolder, "image_encoder"),
).to(self.device, dtype=self.dtype)
self.image_encoder = image_encoder
else:
raise ValueError("`image_encoder` cannot be None when using IP Adapters.")
try:
logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
pretrained_model_name_or_path_or_dict,
subfolder=os.path.join(subfolder, "image_encoder"),
).to(self.device, dtype=self.dtype)
self.image_encoder = image_encoder
except TypeError:
print("IPAdapter: `subfolder` not found, `image_encoder` is None, use image_embeds.")
Contributor:

Let's try to not use try...except here, please.
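One way to drop the try...except, sketched under the assumption that a missing `subfolder` is the only case in which the encoder load should be skipped (not necessarily what the PR settles on):

if subfolder is not None:
    logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(
        pretrained_model_name_or_path_or_dict,
        subfolder=os.path.join(subfolder, "image_encoder"),
    ).to(self.device, dtype=self.dtype)
    self.image_encoder = image_encoder
else:
    # No subfolder means no bundled image encoder; downstream code uses image_embeds.
    logger.info("`subfolder` not provided; leaving `image_encoder` unset and relying on image_embeds.")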


# create feature extractor if it has not been registered to the pipeline yet
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
@@ -153,5 +156,13 @@ def load_ip_adapter(

def set_ip_adapter_scale(self, scale):
for attn_processor in self.unet.attn_processors.values():
if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
if isinstance(
attn_processor,
(
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
LoRAIPAdapterAttnProcessor,
LoRAIPAdapterAttnProcessor2_0,
),
):
attn_processor.scale = scale
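A quick usage sketch of the setter above (pipeline variable and scale value are illustrative):

pipe.set_ip_adapter_scale(0.6)  # now also reaches the LoRA-fused processor variants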
80 changes: 62 additions & 18 deletions src/diffusers/loaders/unet.py
@@ -684,13 +684,20 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict):
diffusers_name = key.replace("proj", "image_embeds")
updated_state_dict[diffusers_name] = value

elif "proj.3.weight" in state_dict:
elif "proj.0.weight" in state_dict:
Member:

Is using this key a better option? Can we use a more resilient condition here to avoid side effects?

Contributor Author:

I changed it to proj.0.weight because both the IP-Adapter Full and FaceID state dicts have it, while IP-Adapter Full's proj.3.weight key is named norm.weight in the FaceID model.
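To make the shape arithmetic below concrete, a worked example with purely illustrative shapes (not taken from a real checkpoint):

# proj.0.weight : (1024, 512)  -> clip_embeddings_dim_in = 512, clip_embeddings_dim_out = 1024
#                                 multiplier = 1024 // 512 = 2
# norm layer    : "proj.3.weight" for IP-Adapter Full, "norm.weight" for FaceID
# norm weight   : (768,)       -> cross_attention_dim = 768
# proj.2.weight : (3072, 1024) -> num_tokens = 3072 // 768 = 4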

# IP-Adapter Full
clip_embeddings_dim = state_dict["proj.0.weight"].shape[0]
cross_attention_dim = state_dict["proj.3.weight"].shape[0]
clip_embeddings_dim_in = state_dict["proj.0.weight"].shape[1]
clip_embeddings_dim_out = state_dict["proj.0.weight"].shape[0]
multiplier = clip_embeddings_dim_out // clip_embeddings_dim_in
norm_layer = "proj.3.weight" if "proj.3.weight" in state_dict else "norm.weight"
cross_attention_dim = state_dict[norm_layer].shape[0]
num_tokens = state_dict["proj.2.weight"].shape[0] // cross_attention_dim

image_projection = MLPProjection(
cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim
cross_attention_dim=cross_attention_dim,
image_embed_dim=clip_embeddings_dim_in,
mult=multiplier,
num_tokens=num_tokens,
)

for key, value in state_dict.items():
@@ -744,14 +751,24 @@ def _load_ip_adapter_weights(self, state_dict):
AttnProcessor2_0,
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
LoRAAttnProcessor,
LoRAAttnProcessor2_0,
LoRAIPAdapterAttnProcessor,
LoRAIPAdapterAttnProcessor2_0,
)

use_lora = False
if "proj.weight" in state_dict["image_proj"]:
# IP-Adapter
num_image_text_embeds = 4
elif "proj.3.weight" in state_dict["image_proj"]:
elif "proj.0.weight" in state_dict["image_proj"]:
# IP-Adapter Full Face
num_image_text_embeds = 257 # 256 CLIP tokens + 1 CLS token
for k in state_dict["ip_adapter"].keys():
if "lora" in k:
num_image_text_embeds = 4
use_lora = True
break
else:
# IP-Adapter Plus
num_image_text_embeds = state_dict["image_proj"]["latents"].shape[1]
@@ -774,20 +791,47 @@
block_id = int(name[len("down_blocks.")])
hidden_size = self.config.block_out_channels[block_id]
if cross_attention_dim is None or "motion_modules" in name:
attn_processor_class = (
AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor
)
attn_procs[name] = attn_processor_class()
if use_lora:
attn_processor_class = (
LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
)
attn_procs[name] = attn_processor_class(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
rank=128,
).to(self.device, dtype=self.dtype)
else:
attn_processor_class = (
AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor
)
attn_procs[name] = attn_processor_class()
else:
attn_processor_class = (
IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor
)
attn_procs[name] = attn_processor_class(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
scale=1.0,
num_tokens=num_image_text_embeds,
).to(dtype=self.dtype, device=self.device)
if use_lora:
attn_processor_class = (
LoRAIPAdapterAttnProcessor2_0
if hasattr(F, "scaled_dot_product_attention")
else LoRAIPAdapterAttnProcessor
)
attn_procs[name] = attn_processor_class(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
scale=1.0,
rank=128,
Member:

Do we need to make rank an argument?

Contributor Author:

Not for now, but new IP-Adapter checkpoints using different LoRA ranks may be released in the future. Do you think it is better to remove it for now?
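If future checkpoints do vary, one hypothetical alternative to hard-coding rank=128 is to infer it from a LoRA down-projection weight (the key pattern here is an assumption, not from this PR):

# A LoRA down weight has shape (rank, in_features), so shape[0] is the rank.
lora_rank = next(
    w.shape[0]
    for k, w in state_dict["ip_adapter"].items()
    if "lora" in k and k.endswith("down.weight")
)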

num_tokens=num_image_text_embeds,
).to(dtype=self.dtype, device=self.device)

else:
attn_processor_class = (
IPAdapterAttnProcessor2_0
if hasattr(F, "scaled_dot_product_attention")
else IPAdapterAttnProcessor
)
attn_procs[name] = attn_processor_class(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
scale=1.0,
num_tokens=num_image_text_embeds,
).to(dtype=self.dtype, device=self.device)

value_dict = {}
for k, w in attn_procs[name].state_dict().items():