diff --git a/src/diffusers/image_processor.py b/src/diffusers/image_processor.py index 00d8588d5a2a..c54af2ab2aa5 100644 --- a/src/diffusers/image_processor.py +++ b/src/diffusers/image_processor.py @@ -236,7 +236,7 @@ def denormalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, to `np.ndarray` or `torch.Tensor`: The denormalized image array. """ - return (images / 2 + 0.5).clamp(0, 1) + return (images * 0.5 + 0.5).clamp(0, 1) @staticmethod def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image: @@ -537,6 +537,27 @@ def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image: return image + def _denormalize_conditionally( + self, images: torch.Tensor, do_denormalize: Optional[List[bool]] = None + ) -> torch.Tensor: + r""" + Denormalize a batch of images based on a condition list. + + Args: + images (`torch.Tensor`): + The input image tensor. + do_denormalize (`Optional[List[bool]`, *optional*, defaults to `None`): + A list of booleans indicating whether to denormalize each image in the batch. If `None`, will use the + value of `do_normalize` in the `VaeImageProcessor` config. + """ + if do_denormalize is None: + return self.denormalize(images) if self.config.do_normalize else images + + # De-normalizing a batch and selectively torch.stack'ing the results turns out to be + # significantly faster than performing a lot of smaller denormalizations + denormalized = self.denormalize(images) + return torch.stack([denormalized[i] if do_denormalize[i] else images[i] for i in range(images.shape[0])]) + def get_default_height_width( self, image: Union[PIL.Image.Image, np.ndarray, torch.Tensor], @@ -752,12 +773,7 @@ def postprocess( if output_type == "latent": return image - if do_denormalize is None: - do_denormalize = [self.config.do_normalize] * image.shape[0] - - image = torch.stack( - [self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])] - ) + image = self._denormalize_conditionally(image, do_denormalize) if output_type == "pt": return image @@ -966,12 +982,7 @@ def postprocess( deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False) output_type = "np" - if do_denormalize is None: - do_denormalize = [self.config.do_normalize] * image.shape[0] - - image = torch.stack( - [self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])] - ) + image = self._denormalize_conditionally(image, do_denormalize) image = self.pt_to_numpy(image)