refactor prepare_mask_and_masked_image with VaeImageProcessor #4444
Conversation
```python
mask[mask < 0.5] = 0
mask[mask > 0.5] = 1
```
Shouldn't we do this in the preprocess function?
```diff
@@ -52,8 +52,16 @@ def __init__(
         resample: str = "lanczos",
         do_normalize: bool = True,
         do_convert_rgb: bool = False,
+        do_convert_grayscale: bool = False,
```
Nice idea!
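For context, this is roughly how a pipeline can now instantiate a dedicated mask processor with the new flag (a sketch; the `vae_scale_factor` value and the exact call site are assumptions, not quoted from the PR):

```python
from diffusers.image_processor import VaeImageProcessor

# Sketch: a processor dedicated to mask images. `do_normalize=False` keeps the
# mask in [0, 1], and `do_convert_grayscale=True` collapses RGB masks to one channel.
mask_processor = VaeImageProcessor(
    vae_scale_factor=8,  # typical SD VAE downsampling factor (assumption)
    do_normalize=False,
    do_binarize=True,
    do_convert_grayscale=True,
)
```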
src/diffusers/image_processor.py
```python
):
    super().__init__()
    if do_convert_rgb and do_convert_grayscale:
        warnings.warn(
```
Maybe better to throw an error here, actually.
Very nice first draft! cc @sayakpaul @pcuenca @williamberman here for a review as well
```python
raise ValueError(
    "`do_convert_rgb` and `do_convert_grayscale` can not both be set to `True`;"
    " if you intended to convert the image into RGB format, please set `do_convert_grayscale = False`."
    " If you intended to convert the image into grayscale format, please set `do_convert_rgb = False`."
)
```
Very descriptive! Looks good.
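A quick way to exercise the guard above (a standalone sketch using pytest, not taken from the PR's test suite):

```python
import pytest
from diffusers.image_processor import VaeImageProcessor

# Both conversion flags together are contradictory, so construction should fail.
with pytest.raises(ValueError):
    VaeImageProcessor(do_convert_rgb=True, do_convert_grayscale=True)
```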
src/diffusers/image_processor.py
```python
image(`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`):
    the image input, can be a PIL image, numpy array or pytorch tensor. if it is a numpy array, should have
    shape [batch, height, width] or [batch, height, width, channel] if it is a pytorch tensor, should have
    shape [batch, channel, height, width]
height (`int`, *optional*, defaults to `None`):
    The height in preprocessed image. If `None`, will use the height of `image` input
width (`int`, *optional*, defaults to `None`):
    The width in preprocessed. If `None`, will use the width of the `image` input
```
Let's maintain casing:
- "the" -> "The" when it's placed at the beginning.
- Full stops to end the sentences.
```python
image[image < 0.5] = 0
image[image >= 0.5] = 1
return image
```
Should 0 and 1 be registered into the config vars?
Think 0 and 1 are global enough to not have to be added to the config.
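To illustrate the thresholding on a toy tensor (a standalone sketch, not from the PR):

```python
import torch

mask = torch.tensor([0.0, 0.2, 0.5, 0.7, 1.0])
mask[mask < 0.5] = 0   # everything below the threshold becomes background
mask[mask >= 0.5] = 1  # everything at or above it becomes foreground
print(mask)  # tensor([0., 0., 1., 1., 1.])
```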
```python
if isinstance(image, torch.Tensor):
    # if image is a pytorch tensor, it could have 2 possible shapes:
    # 1. batch x height x width: we should insert the channel dimension at position 1
    # 2. channel x height x width: we should insert the batch dimension at position 0;
    #    however, since both the channel and batch dimensions have size 1, it is the same to insert at position 1
    # for simplicity, we insert a dimension of size 1 at position 1 for both cases
    image = image.unsqueeze(1)
else:
    # if it is a numpy array, it could have 2 possible shapes:
    # 1. batch x height x width: insert the channel dimension at the last position
    # 2. height x width x channel: insert the batch dimension at the first position
    if image.shape[-1] == 1:
        image = np.expand_dims(image, axis=0)
    else:
        image = np.expand_dims(image, axis=-1)
```
For easier operability, does it make sense to first convert the input tensor to a NumPy array and then operate from there?
The output of `preprocess` is PyTorch tensors, so why would we want to convert to a NumPy array first?

The reason we want this step here is that the rest of our preprocessing logic assumes a 3D tensor has shape channel x height x width, which doesn't apply to grayscale images.

For example, if we have a tensor with shape (5, 256, 256) and try to process it with the current logic directly, we will do this: (5, 256, 256) -> put it into a list: [(5, 256, 256)] -> torch.stack the list when not 4D: (1, 5, 256, 256). The correct result would be (5, 1, 256, 256).
Why can't we add a dummy channel to represent the grayscale images then? Just trying to understand it better.
@sayakpaul yeah, and that's what this section of code is doing: adding a dummy channel to represent either the missing channel or batch dimension, to remove the ambiguity.
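A small standalone demonstration of the ambiguity and the fix (illustrative only, not from the PR):

```python
import torch

batch_of_masks = torch.rand(5, 256, 256)  # batch x height x width, no channel dim

# Naive stacking misreads the batch dimension as channels:
wrong = torch.stack([batch_of_masks])  # -> (1, 5, 256, 256)

# Inserting a size-1 dim at position 1 yields the intended layout:
right = batch_of_masks.unsqueeze(1)    # -> (5, 1, 256, 256)

print(wrong.shape, right.shape)
```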
```python
mask = self.mask_processor.preprocess(mask_image, height=height, width=width)

masked_image = init_image * (mask < 0.5)
```
Register 0.5 as a config var?
Do you mean the init of the VAEProcessor or the vae model config? IMO it's general enough to not have to be registered in a config.
Looking quite good.
Let's try to think of a non-exhaustive list of the scenarios we need to cover to ensure this refactor is robust, and write test cases for each of them; see the sketch below.
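For instance, such tests might sweep the input layouts discussed in this thread (a sketch; the test name and constructor arguments are assumptions, not from the PR):

```python
import numpy as np
import pytest
import torch

from diffusers.image_processor import VaeImageProcessor

# Inputs that should all normalize to (batch, 1, height, width) after preprocessing.
MASK_CASES = [
    torch.rand(1, 32, 32),        # torch: batch x height x width
    torch.rand(1, 1, 32, 32),     # torch: batch x channel x height x width
    np.random.rand(1, 32, 32),    # numpy: batch x height x width
    np.random.rand(32, 32, 1),    # numpy: height x width x channel
]

@pytest.mark.parametrize("mask", MASK_CASES)
def test_mask_preprocess_shapes(mask):
    processor = VaeImageProcessor(do_normalize=False, do_convert_grayscale=True)
    out = processor.preprocess(mask, height=32, width=32)
    assert out.shape == (1, 1, 32, 32)
```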
@patrickvonplaten in SD inpaint pipeline, when …
```diff
 )

 # expected range [0,1], normalize to [-1,1]
 do_normalize = self.config.do_normalize
-if image.min() < 0:
+if image.min() < 0 and do_normalize:
```
If the user configured the image_processor to have `do_normalize=False`, the expected range should be [-1,1] and we shouldn't send a warning (this is the case for the controlnet image).
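For example, the controlnet pipelines construct their image processor with normalization disabled, so no warning should fire for inputs already in [-1, 1] (a sketch of that configuration; the `vae_scale_factor` value is an assumption):

```python
from diffusers.image_processor import VaeImageProcessor

# Control images are consumed in [0, 1]: skip the [-1, 1] normalization
# and suppress the out-of-range warning for such inputs.
control_image_processor = VaeImageProcessor(
    vae_scale_factor=8, do_convert_rgb=True, do_normalize=False
)
```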
```diff
@@ -133,7 +133,11 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image=False):
         tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
         dimensions: ``batch x channels x height x width``.
     """
+    warnings.warn(
```
Let's please use the `deprecate` function here:

```python
def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn=True, stacklevel=2):
```
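A hedged sketch of what the replacement call might look like (the removal version string and the message wording are illustrative, not from the PR):

```python
from diffusers.utils import deprecate

deprecation_message = (
    "prepare_mask_and_masked_image is deprecated and will be removed in a"
    " future version. Use VaeImageProcessor.preprocess instead."
)
# "0.30.0" is a placeholder removal version, not the one chosen in the PR.
deprecate("prepare_mask_and_masked_image", "0.30.0", deprecation_message, standard_warn=False)
```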
```python
image: Union[
    torch.FloatTensor,
    PIL.Image.Image,
    np.ndarray,
    List[torch.FloatTensor],
    List[PIL.Image.Image],
    List[np.ndarray],
] = None,
```
Might be worth defining a new type here, actually.
Something like:

```python
PipelineImage = Union[
    torch.FloatTensor,
    PIL.Image.Image,
    np.ndarray,
    List[torch.FloatTensor],
    List[PIL.Image.Image],
    List[np.ndarray],
]
```

and use this.
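With such an alias, the long union collapses at every call site (a sketch; the `PipelineImage` name is the reviewer's suggestion, and the class below is hypothetical):

```python
from typing import List, Optional, Union

import numpy as np
import PIL.Image
import torch

PipelineImage = Union[
    torch.FloatTensor,
    PIL.Image.Image,
    np.ndarray,
    List[torch.FloatTensor],
    List[PIL.Image.Image],
    List[np.ndarray],
]

class InpaintPipelineSketch:
    # Hypothetical signature for illustration, not the merged code.
    def __call__(self, image: Optional[PipelineImage] = None, mask_image: Optional[PipelineImage] = None):
        ...
```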
Looks cool! Can we apply the changes also right away to SDXL and SDXL controlnet?
Updated SDXL-inpainting. SDXL controlnet does not have a mask input, and we already use an image_processor to process the control condition :) We've added an image_processor to all the other SDXL pipelines wherever there is image input.
Don't know why this test failed - I can't reproduce it on my machine; also, this PR did not touch the paint_by_example pipeline at all.
Very nice! Let's merge it before we run into ugly merge conflicts :-)
…gface#4444) * refactor image processor for mask --------- Co-authored-by: yiyixuxu <yixu310@gmail,com>
First attempt to refactor inpainting pipelines with VaeImageProcessor. Lots of tests need to be added.
This example works as expected after refactoring.
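The example referenced above was not preserved in this conversation; a typical inpainting call after the refactor would look roughly like this (the checkpoint name and image URLs follow the diffusers docs and are assumptions here):

```python
import torch
from diffusers import StableDiffusionInpaintPipeline
from diffusers.utils import load_image

pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16
).to("cuda")

img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"

init_image = load_image(img_url)   # base image
mask_image = load_image(mask_url)  # white pixels are repainted

result = pipe(
    prompt="Face of a yellow cat, high resolution, sitting on a park bench",
    image=init_image,
    mask_image=mask_image,
).images[0]
result.save("inpainted.png")
```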