
refactor prepare_mask_and_masked_image with VaeImageProcessor #4444

Merged: 24 commits into main, Aug 25, 2023

Conversation

@yiyixuxu (Collaborator) commented Aug 3, 2023

First attempt to refactor the inpainting pipelines with VaeImageProcessor. Lots of tests still need to be added.

This example works as expected after the refactoring:

import PIL.Image
import requests
import torch
from io import BytesIO

from diffusers import StableDiffusionInpaintPipeline


def download_image(url):
    # fetch the image over HTTP and return it as an RGB PIL image
    response = requests.get(url)
    return PIL.Image.open(BytesIO(response.content)).convert("RGB")


img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"

init_image = download_image(img_url).resize((512, 512))
mask_image = download_image(mask_url).resize((512, 512))

# fix the random seed so the result is reproducible
generator = torch.Generator(device="cuda").manual_seed(0)

pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")

prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
image = pipe(prompt=prompt, image=init_image, mask_image=mask_image, generator=generator).images[0]
image.save("yellow_cat.png")

@HuggingFaceDocBuilderDev commented Aug 3, 2023

The documentation is not available anymore as the PR was closed or merged.

Comment on lines 875 to 876
mask[mask < 0.5] = 0
mask[mask > 0.5] = 1
Contributor:

Shouldn't we do this in the preprocess function?
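A minimal sketch of what folding the thresholding into a preprocess-style helper could look like; the helper name and the >= convention at exactly 0.5 are assumptions for illustration, not the PR's final API:

import torch

def binarize_mask(mask: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
    # snap soft mask values to hard 0/1 so downstream blending is unambiguous
    mask = mask.clone()
    mask[mask < threshold] = 0
    mask[mask >= threshold] = 1
    return mask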

@@ -52,8 +52,16 @@ def __init__(
        resample: str = "lanczos",
        do_normalize: bool = True,
        do_convert_rgb: bool = False,
        do_convert_grayscale: bool = False,
Contributor:

Nice idea!
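For context, this is roughly how a mask-specific processor could then be constructed; the parameter names follow the diff above, while `do_binarize` and the values shown are assumptions based on the binarize logic discussed later in this review:

from diffusers.image_processor import VaeImageProcessor

# a processor dedicated to masks: grayscale, binary values, no [-1, 1] normalization
mask_processor = VaeImageProcessor(
    vae_scale_factor=8,
    do_normalize=False,
    do_binarize=True,
    do_convert_grayscale=True,
)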

    ):
        super().__init__()
        if do_convert_rgb and do_convert_grayscale:
            warnings.warn(
Contributor:

Maybe better to throw an error here, actually.

@patrickvonplaten (Contributor):

Very nice first draft! cc @sayakpaul @pcuenca @williamberman here for a review as well.

Comment on lines +60 to +64
raise ValueError(
    "`do_convert_rgb` and `do_convert_grayscale` can not both be set to `True`,"
    " if you intended to convert the image into RGB format, please set `do_convert_grayscale = False`.",
    " if you intended to convert the image into grayscale format, please set `do_convert_rgb = False`",
)
Member:
Very descriptive! Looks good.

Comment on lines 157 to 164
image(`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`):
    the image input, can be a PIL image, numpy array or pytorch tensor. if it is a numpy array, should have
    shape [batch, height, width] or [batch, height, width, channel] if it is a pytorch tensor, should have
    shape [batch, channel, height, width]
height (`int`, *optional*, defaults to `None`):
    The height in preprocessed image. If `None`, will use the height of `image` input
width (`int`, *optional*, defaults to `None`):
    The width in preprocessed. If `None`, will use the width of the `image` input
Member:
Let's maintain casing:

  • "the" -> "The" when it's placed at the beginning of a sentence.
  • Full stops to end the sentences.

Comment on lines +205 to +207
image[image < 0.5] = 0
image[image >= 0.5] = 1
return image
Member:
Should 0 and 1 be registered into the config vars?

Contributor:
I think 0 and 1 are global enough to not have to be added to the config.
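For reference, the thresholding above behaves as follows; note that values at exactly 0.5 map to 1:

import torch

mask = torch.tensor([0.1, 0.5, 0.9])
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
print(mask)  # tensor([0., 1., 1.])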

Comment on lines +222 to +237
if isinstance(image, torch.Tensor):
    # if image is a pytorch tensor, it could have 2 possible shapes:
    # 1. batch x height x width: we should insert the channel dimension at position 1
    # 2. channel x height x width: we should insert the batch dimension at position 0;
    #    however, since both the channel and batch dimensions have size 1, it is the same to insert at position 1
    # for simplicity, we insert a dimension of size 1 at position 1 for both cases
    image = image.unsqueeze(1)
else:
    # if it is a numpy array, it could have 2 possible shapes:
    # 1. batch x height x width: insert the channel dimension in the last position
    # 2. height x width x channel: insert the batch dimension in the first position
    if image.shape[-1] == 1:
        image = np.expand_dims(image, axis=0)
    else:
        image = np.expand_dims(image, axis=-1)

Member:
For easier operability, does it make sense to first convert the input tensor to a NumPy array and then operate from there?

@yiyixuxu (Collaborator, Author) commented Aug 4, 2023:

@sayakpaul

The output of preprocess is PyTorch tensors, so why would we want to convert to a NumPy array first?

The reason we want this step here is that the rest of our preprocessing logic assumes a 3D tensor has shape channel x height x width, which doesn't apply to grayscale images.

For example, if we have a tensor with shape (5, 256, 256) and try to process it with the current logic directly, we will do this:
(5, 256, 256) -> put it into a list: [(5, 256, 256)] -> torch.stack the list when not 4D: (1, 5, 256, 256)

The correct result would be (5, 1, 256, 256).

Member:

Why can't we add a dummy channel to represent the grayscale images, then? Just trying to understand it better.

Collaborator (Author):

@sayakpaul
Yeah, and that's what this section of code is doing: adding a dummy channel to represent either the missing channel or batch dimension, to remove the ambiguity.
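To make the shape argument concrete, a minimal reproduction of the tensor case from the discussion above:

import torch

# a batch of 5 single-channel (grayscale) masks, shape (batch, height, width)
masks = torch.rand(5, 256, 256)

# inserting a dummy dimension at position 1 yields the canonical
# (batch, channel, height, width) layout the rest of the pipeline expects
masks = masks.unsqueeze(1)
print(masks.shape)  # torch.Size([5, 1, 256, 256])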


mask = self.mask_processor.preprocess(mask_image, height=height, width=width)

masked_image = init_image * (mask < 0.5)
Member:
Register 0.5 as a config var?

Contributor:
Do you mean in the init of the VaeImageProcessor or the VAE model config? IMO it's general enough not to have to be registered in a config.
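To illustrate what the 0.5 threshold does in the line under discussion (the shapes below are illustrative, not the pipeline's actual values):

import torch

# mask: (1, 1, H, W) in [0, 1]; init_image: (1, 3, H, W) in [-1, 1]
mask = torch.rand(1, 1, 64, 64)
init_image = torch.rand(1, 3, 64, 64) * 2 - 1

# (mask < 0.5) is a boolean tensor that broadcasts over the channel dimension,
# zeroing out the region to be inpainted and keeping the rest of the image
masked_image = init_image * (mask < 0.5)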

@sayakpaul (Member) left a comment:
Looking quite good.

Let's try to think of a non-exhaustive list of the scenarios we need to cover to ensure this refactor is robust, and write test cases for each of them.

@yiyixuxu (Collaborator, Author) commented Aug 5, 2023

@patrickvonplaten
Should we unify the default behavior for when the user does not pass the height and width arguments?

In the SD inpaint pipeline, when height and width are None, it defaults to sample_size * vae_scale_factor; in controlnet inpaint, it resizes to the closest multiple of 8, which I think makes much more sense:
https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py#L866
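A sketch of the two default behaviors being compared; the function names are illustrative, not the pipelines' actual attributes:

def default_hw_sd_inpaint(sample_size: int, vae_scale_factor: int) -> tuple:
    # SD inpaint: a fixed default, e.g. 64 * 8 = 512, regardless of the input image size
    side = sample_size * vae_scale_factor
    return side, side

def default_hw_controlnet_inpaint(image_height: int, image_width: int) -> tuple:
    # controlnet inpaint: keep the input size, rounded down to the nearest multiple of 8
    return image_height // 8 * 8, image_width // 8 * 8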

)

# expected range [0,1], normalize to [-1,1]
do_normalize = self.config.do_normalize
- if image.min() < 0:
+ if image.min() < 0 and do_normalize:
Collaborator (Author):

If the user configured the image_processor to have do_normalize=False, the expected range is [-1, 1] and we shouldn't send a warning (this is the case for the controlnet image).
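A standalone sketch of the adjusted check; the function name and warning text are assumptions for illustration, not the actual code:

import warnings
import torch

def maybe_normalize(image: torch.Tensor, do_normalize: bool) -> torch.Tensor:
    # only warn about a [-1, 1] input when normalization was actually requested
    if image.min() < 0 and do_normalize:
        warnings.warn(
            "Passing `image` with values in [-1, 1] is deprecated when"
            " `do_normalize=True`; the expected range is [0, 1].",
            FutureWarning,
        )
        do_normalize = False  # treat the input as already normalized
    return 2.0 * image - 1.0 if do_normalize else image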

@@ -133,7 +133,11 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image=False
    tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
    dimensions: ``batch x channels x height x width``.
    """

    warnings.warn(
Contributor:
Let's please use the deprecate function here:

def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn=True, stacklevel=2):
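A sketch of what that could look like here; the removal version and message are assumptions, to be chosen by the maintainers:

from diffusers.utils import deprecate

deprecation_message = (
    "`prepare_mask_and_masked_image` is deprecated; use"
    " `VaeImageProcessor.preprocess` instead."
)
# emits a standardized FutureWarning tied to a target removal version
deprecate("prepare_mask_and_masked_image", "1.0.0", deprecation_message, standard_warn=False)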

Comment on lines 928 to 935
        image: Union[
            torch.FloatTensor,
            PIL.Image.Image,
            np.ndarray,
            List[torch.FloatTensor],
            List[PIL.Image.Image],
            List[np.ndarray],
        ] = None,
Contributor:
Might be worth defining a new type here, actually.

Contributor:
Something like:

PipelineImage = Union[
    torch.FloatTensor,
    PIL.Image.Image,
    np.ndarray,
    List[torch.FloatTensor],
    List[PIL.Image.Image],
    List[np.ndarray],
]

and use this.

@patrickvonplaten (Contributor) left a comment:
Looks cool! Can we also apply the changes right away to SDXL and SDXL ControlNet?

@yiyixuxu (Collaborator, Author) commented Aug 24, 2023

@patrickvonplaten

"Can we apply the changes also right away to SDXL and SDXL controlnet?"

Updated SDXL inpainting.

SDXL ControlNet does not have a mask input, and we already use an image_processor to process the control condition :) We've added an image_processor to all the other SDXL pipelines wherever there is an image input.

@yiyixuxu (Collaborator, Author) commented Aug 24, 2023

I don't know why this test failed; I can't reproduce it on my machine. Also, this PR did not touch the paint_by_example pipeline at all:
https://github.com/huggingface/diffusers/actions/runs/5959195493/job/16165075058?pr=4444#step:6:13213

@patrickvonplaten (Contributor) left a comment:
Very nice! Let's merge it before we run into ugly merge conflicts :-)

@yiyixuxu yiyixuxu merged commit b7b1a30 into main Aug 25, 2023
@yiyixuxu yiyixuxu deleted the inpaint-preprocess branch August 25, 2023 18:18
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request on Dec 25, 2023:

refactor prepare_mask_and_masked_image with VaeImageProcessor (huggingface#4444)

* refactor image processor for mask
---------
Co-authored-by: yiyixuxu <yixu310@gmail,com>
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request on Apr 26, 2024:

refactor prepare_mask_and_masked_image with VaeImageProcessor (huggingface#4444)

* refactor image processor for mask
---------
Co-authored-by: yiyixuxu <yixu310@gmail,com>