Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions src/diffusers/loaders/ip_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,10 +210,12 @@ def unload_ip_adapter(self):
self.image_encoder = None
self.register_to_config(image_encoder=[None, None])

# remove feature extractor
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None:
self.feature_extractor = None
self.register_to_config(feature_extractor=[None, None])
# remove feature extractor only when safety_checker is None as safety_checker uses
# the feature_extractor later
if not hasattr(self, "safety_checker"):
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None:
self.feature_extractor = None
self.register_to_config(feature_extractor=[None, None])

# remove hidden encoder
self.unet.encoder_hid_proj = None
Expand Down
67 changes: 43 additions & 24 deletions src/diffusers/pipelines/animatediff/pipeline_animatediff.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,32 +427,38 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state
return image_embeds, uncond_image_embeds

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
if not isinstance(ip_adapter_image, list):
ip_adapter_image = [ip_adapter_image]

if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
raise ValueError(
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
)
def prepare_ip_adapter_image_embeds(
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
):
if ip_adapter_image_embeds is None:
if not isinstance(ip_adapter_image, list):
ip_adapter_image = [ip_adapter_image]

image_embeds = []
for single_ip_adapter_image, image_proj_layer in zip(
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
):
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
single_image_embeds, single_negative_image_embeds = self.encode_image(
single_ip_adapter_image, device, 1, output_hidden_state
)
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
raise ValueError(
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
)

if self.do_classifier_free_guidance:
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
single_image_embeds = single_image_embeds.to(device)
image_embeds = []
for single_ip_adapter_image, image_proj_layer in zip(
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
):
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
single_image_embeds, single_negative_image_embeds = self.encode_image(
single_ip_adapter_image, device, 1, output_hidden_state
)
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
single_negative_image_embeds = torch.stack(
[single_negative_image_embeds] * num_images_per_prompt, dim=0
)

image_embeds.append(single_image_embeds)
if self.do_classifier_free_guidance:
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
single_image_embeds = single_image_embeds.to(device)

image_embeds.append(single_image_embeds)
else:
image_embeds = ip_adapter_image_embeds
return image_embeds

# Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
Expand Down Expand Up @@ -620,6 +626,8 @@ def check_inputs(
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
ip_adapter_image=None,
ip_adapter_image_embeds=None,
callback_on_step_end_tensor_inputs=None,
):
if height % 8 != 0 or width % 8 != 0:
Expand Down Expand Up @@ -663,6 +671,11 @@ def check_inputs(
f" {negative_prompt_embeds.shape}."
)

if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
raise ValueError(
"Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
)

# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents
def prepare_latents(
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
Expand Down Expand Up @@ -882,6 +895,7 @@ def __call__(
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
ip_adapter_image: Optional[PipelineImageInput] = None,
ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
Expand Down Expand Up @@ -931,6 +945,9 @@ def __call__(
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
ip_adapter_image: (`PipelineImageInput`, *optional*):
Optional image input to work with IP Adapters.
ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
Pre-generated image embeddings for IP-Adapter. If not
provided, embeddings are computed from the `ip_adapter_image` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated video. Choose between `torch.FloatTensor`, `PIL.Image` or
`np.array`.
Expand Down Expand Up @@ -992,6 +1009,8 @@ def __call__(
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
ip_adapter_image,
ip_adapter_image_embeds,
callback_on_step_end_tensor_inputs,
)

Expand Down Expand Up @@ -1030,9 +1049,9 @@ def __call__(
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

if ip_adapter_image is not None:
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
image_embeds = self.prepare_ip_adapter_image_embeds(
ip_adapter_image, device, batch_size * num_videos_per_prompt
ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_videos_per_prompt
)

# 4. Prepare timesteps
Expand Down
67 changes: 43 additions & 24 deletions src/diffusers/pipelines/controlnet/pipeline_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,32 +507,38 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state
return image_embeds, uncond_image_embeds

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
if not isinstance(ip_adapter_image, list):
ip_adapter_image = [ip_adapter_image]

if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
raise ValueError(
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
)
def prepare_ip_adapter_image_embeds(
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
):
if ip_adapter_image_embeds is None:
if not isinstance(ip_adapter_image, list):
ip_adapter_image = [ip_adapter_image]

image_embeds = []
for single_ip_adapter_image, image_proj_layer in zip(
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
):
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
single_image_embeds, single_negative_image_embeds = self.encode_image(
single_ip_adapter_image, device, 1, output_hidden_state
)
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
raise ValueError(
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
)

if self.do_classifier_free_guidance:
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
single_image_embeds = single_image_embeds.to(device)
image_embeds = []
for single_ip_adapter_image, image_proj_layer in zip(
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
):
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
single_image_embeds, single_negative_image_embeds = self.encode_image(
single_ip_adapter_image, device, 1, output_hidden_state
)
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
single_negative_image_embeds = torch.stack(
[single_negative_image_embeds] * num_images_per_prompt, dim=0
)

image_embeds.append(single_image_embeds)
if self.do_classifier_free_guidance:
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
single_image_embeds = single_image_embeds.to(device)

image_embeds.append(single_image_embeds)
else:
image_embeds = ip_adapter_image_embeds
return image_embeds

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
Expand Down Expand Up @@ -588,6 +594,8 @@ def check_inputs(
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
ip_adapter_image=None,
ip_adapter_image_embeds=None,
controlnet_conditioning_scale=1.0,
control_guidance_start=0.0,
control_guidance_end=1.0,
Expand Down Expand Up @@ -726,6 +734,11 @@ def check_inputs(
if end > 1.0:
raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")

if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
raise ValueError(
"Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
)

def check_image(self, image, prompt, prompt_embeds):
image_is_pil = isinstance(image, PIL.Image.Image)
image_is_tensor = isinstance(image, torch.Tensor)
Expand Down Expand Up @@ -910,6 +923,7 @@ def __call__(
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
ip_adapter_image: Optional[PipelineImageInput] = None,
ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
Expand Down Expand Up @@ -974,6 +988,9 @@ def __call__(
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
Pre-generated image embeddings for IP-Adapter. If not
provided, embeddings are computed from the `ip_adapter_image` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Expand Down Expand Up @@ -1060,6 +1077,8 @@ def __call__(
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
ip_adapter_image,
ip_adapter_image_embeds,
controlnet_conditioning_scale,
control_guidance_start,
control_guidance_end,
Expand Down Expand Up @@ -1111,9 +1130,9 @@ def __call__(
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

if ip_adapter_image is not None:
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
image_embeds = self.prepare_ip_adapter_image_embeds(
ip_adapter_image, device, batch_size * num_images_per_prompt
ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_images_per_prompt
)

# 4. Prepare image
Expand Down
Loading