@sayakpaul
Member

What does this PR do?

As per the discussion of #6830.

Testing script:

from diffusers import AutoPipelineForText2Image
from diffusers.models import ImageProjection
import torch
from diffusers.utils import load_image


def encode_image(image_encoder, feature_extractor, image, device, num_images_per_prompt, output_hidden_states=None):
    dtype = next(image_encoder.parameters()).dtype

    if not isinstance(image, torch.Tensor):
        image = feature_extractor(image, return_tensors="pt").pixel_values

    image = image.to(device=device, dtype=dtype)
    if output_hidden_states:
        image_enc_hidden_states = image_encoder(image, output_hidden_states=True).hidden_states[-2]
        image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
        uncond_image_enc_hidden_states = image_encoder(
            torch.zeros_like(image), output_hidden_states=True
        ).hidden_states[-2]
        uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
        return image_enc_hidden_states, uncond_image_enc_hidden_states
    else:
        image_embeds = image_encoder(image).image_embeds
        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        uncond_image_embeds = torch.zeros_like(image_embeds)

        return image_embeds, uncond_image_embeds


@torch.no_grad()
def prepare_ip_adapter_image_embeds(
    unet,
    image_encoder,
    feature_extractor,
    ip_adapter_image,
    do_classifier_free_guidance,
    device,
    num_images_per_prompt,
):
    if not isinstance(ip_adapter_image, list):
        ip_adapter_image = [ip_adapter_image]

    if len(ip_adapter_image) != len(unet.encoder_hid_proj.image_projection_layers):
        raise ValueError(
            f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
        )

    image_embeds = []
    for single_ip_adapter_image, image_proj_layer in zip(
        ip_adapter_image, unet.encoder_hid_proj.image_projection_layers
    ):
        output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
        single_image_embeds, single_negative_image_embeds = encode_image(
            image_encoder, feature_extractor, single_ip_adapter_image, device, 1, output_hidden_state
        )
        single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
        single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)

        if do_classifier_free_guidance:
            single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
            single_image_embeds = single_image_embeds.to(device)

        image_embeds.append(single_image_embeds)

    return image_embeds


pipeline = AutoPipelineForText2Image.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to(
    "cuda"
)
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
pipeline.set_ip_adapter_scale(0.6)

image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/load_neg_embed.png"
)
image_embeds = prepare_ip_adapter_image_embeds(
    unet=pipeline.unet,
    image_encoder=pipeline.image_encoder,
    feature_extractor=pipeline.feature_extractor,
    ip_adapter_image=image,
    do_classifier_free_guidance=True,
    device="cuda",
    num_images_per_prompt=1,
)
pipeline.unload_ip_adapter()


generator = torch.Generator(device="cpu").manual_seed(33)
images = pipeline(
    prompt="best quality, high quality, wearing sunglasses",
    ip_adapter_image_embeds=image_embeds,
    negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
    num_inference_steps=50,
    generator=generator,
).images
images[0].save("embeds_out.png")

Here's our cute bear:

[image]

We could introduce static methods, namely _encode_ip_adapter_image() and _prepare_ip_adapter_image_embeds(), and delegate the current calls of encode_image() and prepare_ip_adapter_image_embeds() to them, respectively. That way, users would not have to write encode_image() and prepare_ip_adapter_image_embeds() themselves, as shown above.

So the flow would be like:

from diffusers import StableDiffusionPipeline
from diffusers.models import ImageProjection
import torch
from diffusers.utils import load_image


pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to(
    "cuda"
)
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
pipeline.set_ip_adapter_scale(0.6)

image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/load_neg_embed.png"
)
image_embeds = pipeline._prepare_ip_adapter_image_embeds(
    unet=pipeline.unet,
    image_encoder=pipeline.image_encoder,
    feature_extractor=pipeline.feature_extractor,
    ip_adapter_image=image,
    do_classifier_free_guidance=True,
    device="cuda",
    num_images_per_prompt=1,
)
pipeline.unload_ip_adapter()


generator = torch.Generator(device="cpu").manual_seed(33)
images = pipeline(
    prompt="best quality, high quality, wearing sunglasses",
    ip_adapter_image_embeds=image_embeds,
    negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
    num_inference_steps=50,
    generator=generator,
).images
images[0].save("embeds_out.png")
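
The delegation idea described above can be illustrated with a minimal, framework-free sketch. Everything here is an assumption made for illustration: `ToyPipeline`, `toy_feature_extractor` and `toy_image_encoder` are hypothetical stand-ins, not diffusers classes, and the method body is not the library's implementation. It only shows the pattern of an instance method forwarding its own components to a static method that users could also call directly.

```python
class ToyPipeline:
    """Hypothetical stand-in for a pipeline; not the real diffusers class."""

    def __init__(self, image_encoder, feature_extractor):
        self.image_encoder = image_encoder
        self.feature_extractor = feature_extractor

    @staticmethod
    def _encode_ip_adapter_image(image_encoder, feature_extractor, image):
        # Stateless: every component it needs is passed in explicitly,
        # so it can be called without going through a pipeline instance.
        pixel_values = feature_extractor(image)
        return image_encoder(pixel_values)

    def encode_image(self, image):
        # The existing instance method simply forwards its own components.
        return self._encode_ip_adapter_image(self.image_encoder, self.feature_extractor, image)


# Toy components (assumptions for illustration only).
def toy_feature_extractor(image):
    # Scale 8-bit pixel values to the range [0, 1].
    return [p / 255.0 for p in image]


def toy_image_encoder(pixel_values):
    # Collapse the "image" to a single-number embedding (its mean).
    return [sum(pixel_values) / len(pixel_values)]


pipe = ToyPipeline(toy_image_encoder, toy_feature_extractor)
image = [0, 255, 255, 0]

via_instance = pipe.encode_image(image)
via_static = ToyPipeline._encode_ip_adapter_image(toy_image_encoder, toy_feature_extractor, image)
```

Both call paths return the same embedding, which is the property the proposal relies on: existing callers keep using the instance method, while users who only want embeddings can call the static one.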

@sayakpaul sayakpaul requested a review from yiyixuxu February 6, 2024 08:11
@yiyixuxu yiyixuxu mentioned this pull request Feb 6, 2024
6 tasks
Collaborator

@yiyixuxu yiyixuxu left a comment

thanks so much for adding this so quickly:) it's looking great!
I left a comment for the unload_ip_adapter portion

@asomoza
Member

asomoza commented Feb 7, 2024

Thank you @sayakpaul. So as I understand it, we need to pass the image embeddings for each IP Adapter, which is cool. Being able to mix between an image, a list of images, or embeddings for each IP Adapter is exactly what I was looking for.
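
To make the "one entry per IP Adapter" requirement concrete, here is a small sketch modeled on the length check and the stack/cat steps in the testing script from the PR description. The helper names `normalize_ip_adapter_inputs` and `expected_leading_dim` are hypothetical and not part of diffusers; the second helper only restates what the script's `torch.stack` and `torch.cat` calls imply for the leading dimension.

```python
def normalize_ip_adapter_inputs(inputs, num_ip_adapters):
    """Wrap a single input in a list and check there is one entry per IP Adapter."""
    if not isinstance(inputs, list):
        inputs = [inputs]
    if len(inputs) != num_ip_adapters:
        raise ValueError(
            f"Expected one entry per IP Adapter: got {len(inputs)} entries "
            f"for {num_ip_adapters} IP Adapters."
        )
    return inputs


def expected_leading_dim(num_images_per_prompt, do_classifier_free_guidance):
    """Leading dimension of one adapter's embeddings as built in the testing script."""
    # torch.stack([...] * num_images_per_prompt) yields num_images_per_prompt rows;
    # torch.cat([negative, positive]) doubles that under classifier-free guidance.
    factor = 2 if do_classifier_free_guidance else 1
    return factor * num_images_per_prompt


single = normalize_ip_adapter_inputs("bear.png", num_ip_adapters=1)
rows = expected_leading_dim(num_images_per_prompt=1, do_classifier_free_guidance=True)
```

With the settings used in the testing script (one image per prompt, guidance enabled), each adapter's tensor would carry the negative embeddings first and the positive ones second along the leading dimension.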

The only other issue left on my list (maybe we can discuss it later and not in this PR) is that diffusers is the only library/app that does the resampling in the forward of the unet instead of when getting the embeddings for the images. This prevents us from using embeddings from other apps or libraries, and vice versa.

ComfyUI
https://github.com/cubiq/Diffusers_IPAdapter/blob/169add95683d5be6696975375fb70c50795a7947/ip_adapter/ip_adapter.py#L108

InvokeAI
https://github.com/invoke-ai/InvokeAI/blob/79ae9c4e64cb4f64d54c25b1c487501752b8fa84/invokeai/backend/ip_adapter/ip_adapter.py#L138

It would be good to do it here but I remember @yiyixuxu telling me that you were thinking of taking the image projection outside of the unet so maybe we can discuss it then.

@sayakpaul
Member Author

The only other issue left on my list (maybe we can discuss it later and not in this PR) is that diffusers is the only library/app that does the resampling in the forward of the unet instead of when getting the embeddings for the images. This prevents us from using embeddings from other apps or libraries, and vice versa.

I don't understand. What's resampling in this context? Please point to references in the diffusers codebase here.

@asomoza
Member

asomoza commented Feb 7, 2024

Oh sorry, I meant the image projection, which is what is done here in diffusers:

image_embeds = self.encoder_hid_proj(image_embeds)
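
As a rough illustration of the point being raised, the toy sketch below contrasts projecting inside the model's forward pass with projecting at encoding time. It is framework-free and entirely hypothetical: `project`, `toy_unet_forward` and the weight values are assumptions for illustration and do not reflect how `encoder_hid_proj` is actually implemented.

```python
def project(image_embeds, weight):
    """Toy linear projection: one output value per weight row (no bias)."""
    return [sum(w * x for w, x in zip(row, image_embeds)) for row in weight]


def toy_unet_forward(embeds, weight=None):
    """Toy forward pass: projects inside the model only if a weight is given."""
    tokens = project(embeds, weight) if weight is not None else embeds
    return sum(tokens)  # stand-in for cross-attention consuming the tokens


raw = [1.0, 2.0]  # what an image encoder would output (toy values)
weight = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # hypothetical projection weights

# Projection inside the forward pass: callers exchange raw embeddings.
inside = toy_unet_forward(raw, weight)

# Projection at encoding time: callers exchange already-projected embeddings.
projected = project(raw, weight)
outside = toy_unet_forward(projected)
```

The final result is the same either way, but the tensor a user saves and passes around differs (`raw` versus `projected`), which is why embeddings produced under one convention may not be directly usable under the other.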

@sayakpaul
Member Author

Oh in that case, that warrants a separate PR / discussion.

@sayakpaul
Member Author

@yiyixuxu @DN6 I think this is ready for another review.

@yiyixuxu
Collaborator

yiyixuxu commented Feb 7, 2024

@asomoza
yes please open another issue:)

The short answer is:

  1. we will not separate the image projection layer from the unet just for ip-adapter, but this is something we are considering for our unet refactor
  2. I think we do not need to separate the image_projection layers in order to accommodate what you need:)

Collaborator

@yiyixuxu yiyixuxu left a comment

looks good to me:) thank you
can we look into that failing test and make sure it is unrelated here?

@sayakpaul sayakpaul merged commit aa82df5 into main Feb 8, 2024
@sayakpaul sayakpaul deleted the feat-ip-image-embeddings branch February 8, 2024 05:40
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024
…call (huggingface#6868)

* add: support for passing ip adapter image embeddings

* debugging

* make feature_extractor unloading conditioned on safety_checker

* better condition

* type annotation

* index to look into value slices

* more debugging

* debugging

* serialize embeddings dict

* better conditioning

* remove unnecessary prints.

* Update src/diffusers/loaders/ip_adapter.py

Co-authored-by: YiYi Xu <[email protected]>

* make fix-copies and styling.

* styling and further copy fixing.

* fix: check_inputs call in controlnet sdxl img2img pipeline

---------

Co-authored-by: YiYi Xu <[email protected]>