[Community Pipeline] IPAdapter FaceID #6276
Changes from 7 commits
@@ -34,6 +34,8 @@
 from ..models.attention_processor import (
     IPAdapterAttnProcessor,
     IPAdapterAttnProcessor2_0,
+    LoRAIPAdapterAttnProcessor,
+    LoRAIPAdapterAttnProcessor2_0,
 )

 logger = logging.get_logger(__name__)
@@ -46,7 +48,6 @@ class IPAdapterMixin:
     def load_ip_adapter(
         self,
         pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
-        subfolder: str,
         weight_name: str,
         **kwargs,
     ):
@@ -95,6 +96,7 @@ def load_ip_adapter(
         local_files_only = kwargs.pop("local_files_only", None)
         token = kwargs.pop("token", None)
         revision = kwargs.pop("revision", None)
+        subfolder = kwargs.pop("subfolder", None)

         user_agent = {
             "file_type": "attn_procs_weights",
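Since `subfolder` becomes an optional kwarg rather than a required positional argument, callers can now omit it entirely. A minimal usage sketch, with an assumed repo id and weight file name (not part of this diff):

```python
# Hypothetical call after this change; repo id and file name are illustrative.
pipeline.load_ip_adapter(
    "h94/IP-Adapter-FaceID",                   # assumed repo id
    weight_name="ip-adapter-faceid_sd15.bin",  # assumed file name
    # subfolder="models",  # now optional; omit when the weights sit at the repo root
)
```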
@@ -135,14 +137,15 @@ def load_ip_adapter(
         # load CLIP image encoder here if it has not been registered to the pipeline yet
         if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
             if not isinstance(pretrained_model_name_or_path_or_dict, dict):
-                logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
-                image_encoder = CLIPVisionModelWithProjection.from_pretrained(
-                    pretrained_model_name_or_path_or_dict,
-                    subfolder=os.path.join(subfolder, "image_encoder"),
-                ).to(self.device, dtype=self.dtype)
-                self.image_encoder = image_encoder
-            else:
-                raise ValueError("`image_encoder` cannot be None when using IP Adapters.")
+                try:
+                    logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
+                    image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+                        pretrained_model_name_or_path_or_dict,
+                        subfolder=os.path.join(subfolder, "image_encoder"),
+                    ).to(self.device, dtype=self.dtype)
+                    self.image_encoder = image_encoder
+                except TypeError:
+                    print("IPAdapter: `subfolder` not found, `image_encoder` is None, use image_embeds.")
Let's try not to use try...except here, please.
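One way to honor that suggestion, as a sketch rather than part of the diff: gate on `subfolder` explicitly instead of catching the `TypeError` that `os.path.join(None, ...)` raises. Names are reused from the hunk above:

```python
# Sketch of an explicit check instead of try/except (reuses names from the diff above).
if subfolder is not None:
    logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
    self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
        pretrained_model_name_or_path_or_dict,
        subfolder=os.path.join(subfolder, "image_encoder"),
    ).to(self.device, dtype=self.dtype)
else:
    logger.info("IPAdapter: no `subfolder` given, `image_encoder` stays None; pass `image_embeds` instead.")
```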
@@ -707,13 +707,20 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict):
             diffusers_name = key.replace("proj", "image_embeds")
             updated_state_dict[diffusers_name] = value

-        elif "proj.3.weight" in state_dict:
+        elif "proj.0.weight" in state_dict:
             # IP-Adapter Full
-            clip_embeddings_dim = state_dict["proj.0.weight"].shape[0]
-            cross_attention_dim = state_dict["proj.3.weight"].shape[0]
+            clip_embeddings_dim_in = state_dict["proj.0.weight"].shape[1]
+            clip_embeddings_dim_out = state_dict["proj.0.weight"].shape[0]
+            multiplier = clip_embeddings_dim_out // clip_embeddings_dim_in
+            norm_layer = "proj.3.weight" if "proj.3.weight" in state_dict else "norm.weight"
+            cross_attention_dim = state_dict[norm_layer].shape[0]
+            num_tokens = state_dict["proj.2.weight"].shape[0] // cross_attention_dim

             image_projection = IPAdapterFullImageProjection(
-                cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim
+                cross_attention_dim=cross_attention_dim,
+                image_embed_dim=clip_embeddings_dim_in,
+                mult=multiplier,
+                num_tokens=num_tokens,
             )

             for key, value in state_dict.items():
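To make the shape bookkeeping concrete, here is the same arithmetic run on a hypothetical FaceID-style projection state dict. All dimensions are invented for illustration:

```python
import torch

# Hypothetical FaceID-style MLP projection: Linear(512 -> 1024), GELU,
# Linear(1024 -> 4 * 768), LayerNorm(768). Shapes invented for illustration.
state_dict = {
    "proj.0.weight": torch.zeros(1024, 512),   # first Linear: (out_features, in_features)
    "proj.2.weight": torch.zeros(3072, 1024),  # final Linear: num_tokens * cross_attention_dim rows
    "norm.weight": torch.zeros(768),           # LayerNorm over cross_attention_dim
}

clip_embeddings_dim_in = state_dict["proj.0.weight"].shape[1]             # 512
clip_embeddings_dim_out = state_dict["proj.0.weight"].shape[0]            # 1024
multiplier = clip_embeddings_dim_out // clip_embeddings_dim_in            # 2
norm_layer = "proj.3.weight" if "proj.3.weight" in state_dict else "norm.weight"
cross_attention_dim = state_dict[norm_layer].shape[0]                     # 768
num_tokens = state_dict["proj.2.weight"].shape[0] // cross_attention_dim  # 4
print(multiplier, cross_attention_dim, num_tokens)  # 2 768 4
```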
@@ -767,14 +774,24 @@ def _load_ip_adapter_weights(self, state_dict):
             AttnProcessor2_0,
             IPAdapterAttnProcessor,
             IPAdapterAttnProcessor2_0,
+            LoRAAttnProcessor,
+            LoRAAttnProcessor2_0,
+            LoRAIPAdapterAttnProcessor,
+            LoRAIPAdapterAttnProcessor2_0,
         )

+        use_lora = False
         if "proj.weight" in state_dict["image_proj"]:
             # IP-Adapter
             num_image_text_embeds = 4
-        elif "proj.3.weight" in state_dict["image_proj"]:
+        elif "proj.0.weight" in state_dict["image_proj"]:
             # IP-Adapter Full Face
             num_image_text_embeds = 257  # 256 CLIP tokens + 1 CLS token
+            for k in state_dict["ip_adapter"].keys():
+                if "lora" in k:
+                    num_image_text_embeds = 4
+                    use_lora = True
+                    break
         else:
             # IP-Adapter Plus
             num_image_text_embeds = state_dict["image_proj"]["latents"].shape[1]
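FaceID checkpoints are told apart from plain Full Face checkpoints by the LoRA keys in their `ip_adapter` dict. A small illustration with invented key names:

```python
# Invented key names for illustration; FaceID-style checkpoints carry "lora"
# keys alongside the usual to_k_ip/to_v_ip projections.
ip_adapter_keys = [
    "1.to_k_ip.weight",
    "1.to_v_ip.weight",
    "1.to_q_lora.down.weight",  # presence of "lora" flips use_lora
]
use_lora = any("lora" in k for k in ip_adapter_keys)
print(use_lora)  # True -> the rank-128 LoRA attention processors get installed
```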
@@ -797,20 +814,47 @@ def _load_ip_adapter_weights(self, state_dict):
                 block_id = int(name[len("down_blocks.")])
                 hidden_size = self.config.block_out_channels[block_id]
             if cross_attention_dim is None or "motion_modules" in name:
-                attn_processor_class = (
-                    AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor
-                )
-                attn_procs[name] = attn_processor_class()
+                if use_lora:
+                    attn_processor_class = (
+                        LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
+                    )
+                    attn_procs[name] = attn_processor_class(
+                        hidden_size=hidden_size,
+                        cross_attention_dim=cross_attention_dim,
+                        rank=128,
+                    ).to(self.device, dtype=self.dtype)
+                else:
+                    attn_processor_class = (
+                        AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor
+                    )
+                    attn_procs[name] = attn_processor_class()
             else:
-                attn_processor_class = (
-                    IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor
-                )
-                attn_procs[name] = attn_processor_class(
-                    hidden_size=hidden_size,
-                    cross_attention_dim=cross_attention_dim,
-                    scale=1.0,
-                    num_tokens=num_image_text_embeds,
-                ).to(dtype=self.dtype, device=self.device)
+                if use_lora:
+                    attn_processor_class = (
+                        LoRAIPAdapterAttnProcessor2_0
+                        if hasattr(F, "scaled_dot_product_attention")
+                        else LoRAIPAdapterAttnProcessor
+                    )
+                    attn_procs[name] = attn_processor_class(
+                        hidden_size=hidden_size,
+                        cross_attention_dim=cross_attention_dim,
+                        scale=1.0,
+                        rank=128,
+                        num_tokens=num_image_text_embeds,
+                    ).to(dtype=self.dtype, device=self.device)
+                else:
+                    attn_processor_class = (
+                        IPAdapterAttnProcessor2_0
+                        if hasattr(F, "scaled_dot_product_attention")
+                        else IPAdapterAttnProcessor
+                    )
+                    attn_procs[name] = attn_processor_class(
+                        hidden_size=hidden_size,
+                        cross_attention_dim=cross_attention_dim,
+                        scale=1.0,
+                        num_tokens=num_image_text_embeds,
+                    ).to(dtype=self.dtype, device=self.device)

             value_dict = {}
             for k, w in attn_procs[name].state_dict().items():
This is a breaking change; can we leave it as it was here?
done as you suggested!
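For reviewers trying the branch, a hypothetical end-to-end call might look like the following. The repo id, weight file name, and the `image_embeds` call argument are assumptions for illustration (the print message in the loader hunk hints at `image_embeds`), not guarantees made by this diff:

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# FaceID-style weights: the "lora" keys in the checkpoint trigger the rank-128
# LoRA attention processors installed by _load_ip_adapter_weights above.
pipe.load_ip_adapter(
    "h94/IP-Adapter-FaceID",                   # assumed repo id
    weight_name="ip-adapter-faceid_sd15.bin",  # assumed file name
    # no subfolder: image_encoder stays None, so face embeddings are supplied directly
)

# Placeholder for face-recognition embeddings (e.g. from InsightFace) that would
# normally be extracted from a reference photo.
face_embeds = torch.randn(1, 512, dtype=torch.float16, device="cuda")

image = pipe(
    prompt="portrait photo, natural light",
    image_embeds=face_embeds,  # assumed argument name for this community pipeline
    num_inference_steps=30,
).images[0]
```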