From a7a15aabf3eaea32484c795d418a59d45d1b5e68 Mon Sep 17 00:00:00 2001
From: Pedro Cuenca
Date: Sat, 23 Sep 2023 11:04:12 +0000
Subject: [PATCH] Flax SDXL img2img

---
 src/diffusers/__init__.py                     |   2 +
 src/diffusers/pipelines/__init__.py           |   2 +
 .../pipelines/stable_diffusion_xl/__init__.py |   2 +
 ...peline_flax_stable_diffusion_xl_img2img.py | 430 ++++++++++++++++++
 .../dummy_flax_and_transformers_objects.py    |  15 +
 5 files changed, 451 insertions(+)
 create mode 100644 src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl_img2img.py

diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 25db36bb1c7a..a2415e4d8474 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -396,6 +396,7 @@
             "FlaxStableDiffusionImg2ImgPipeline",
             "FlaxStableDiffusionInpaintPipeline",
             "FlaxStableDiffusionPipeline",
+            "FlaxStableDiffusionXLImg2ImgPipeline",
             "FlaxStableDiffusionXLPipeline",
         ]
     )
@@ -694,6 +695,7 @@
         FlaxStableDiffusionImg2ImgPipeline,
         FlaxStableDiffusionInpaintPipeline,
         FlaxStableDiffusionPipeline,
+        FlaxStableDiffusionXLImg2ImgPipeline,
         FlaxStableDiffusionXLPipeline,
     )

diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index 8a55cca5ab3f..4883767269e9 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -239,6 +239,7 @@
     )
     _import_structure["stable_diffusion_xl"].extend(
         [
+            "FlaxStableDiffusionXLImg2ImgPipeline",
             "FlaxStableDiffusionXLPipeline",
         ]
     )
@@ -446,6 +447,7 @@
         FlaxStableDiffusionPipeline,
     )
     from .stable_diffusion_xl import (
+        FlaxStableDiffusionXLImg2ImgPipeline,
         FlaxStableDiffusionXLPipeline,
     )

diff --git a/src/diffusers/pipelines/stable_diffusion_xl/__init__.py b/src/diffusers/pipelines/stable_diffusion_xl/__init__.py
index 90dfef809bca..4158943f7b13 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/__init__.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/__init__.py
@@ -34,6 +34,7 @@
     _additional_imports.update({"PNDMSchedulerState": PNDMSchedulerState})

     _import_structure["pipeline_flax_stable_diffusion_xl"] = ["FlaxStableDiffusionXLPipeline"]
+    _import_structure["pipeline_flax_stable_diffusion_xl_img2img"] = ["FlaxStableDiffusionXLImg2ImgPipeline"]


 if TYPE_CHECKING:
@@ -55,6 +56,7 @@
         from ...utils.dummy_flax_objects import *
     else:
         from .pipeline_flax_stable_diffusion_xl import (
             FlaxStableDiffusionXLPipeline,
         )
+        from .pipeline_flax_stable_diffusion_xl_img2img import FlaxStableDiffusionXLImg2ImgPipeline
         from .pipeline_output import FlaxStableDiffusionXLPipelineOutput

diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl_img2img.py
new file mode 100644
index 000000000000..b0236000ebe1
--- /dev/null
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl_img2img.py
@@ -0,0 +1,430 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
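+
+"""Flax Stable Diffusion XL image-to-image (refiner-style) pipeline.
+
+The example below is a minimal usage sketch rather than a tested recipe: it assumes the
+checkpoint ships Flax weights, and the image URL is a placeholder.
+
+```py
+>>> import jax
+>>> import jax.numpy as jnp
+>>> from flax.jax_utils import replicate
+>>> from flax.training.common_utils import shard
+>>> from diffusers import FlaxStableDiffusionXLImg2ImgPipeline
+>>> from diffusers.utils import load_image
+
+>>> pipeline, params = FlaxStableDiffusionXLImg2ImgPipeline.from_pretrained(
+...     "stabilityai/stable-diffusion-xl-refiner-1.0", dtype=jnp.bfloat16
+... )
+
+>>> num_devices = jax.device_count()
+>>> prompt = num_devices * ["a fantasy landscape"]
+>>> init_image = num_devices * [load_image("https://example.com/input.png")]
+>>> prompt_ids, processed_image = pipeline.prepare_inputs(prompt, init_image)
+
+>>> params = replicate(params)
+>>> prompt_ids, processed_image = shard(prompt_ids), shard(processed_image)
+>>> rng = jax.random.split(jax.random.PRNGKey(0), num_devices)
+
+>>> images = pipeline(prompt_ids, processed_image, params, rng, strength=0.3, jit=True).images
+```
+"""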
+
+from functools import partial
+from typing import Dict, List, Optional, Union
+
+import jax
+import jax.numpy as jnp
+from flax.core.frozen_dict import FrozenDict
+from PIL import Image
+from transformers import CLIPTokenizer, FlaxCLIPTextModel
+
+from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel
+from ...schedulers import (
+    FlaxDDIMScheduler,
+    FlaxDPMSolverMultistepScheduler,
+    FlaxLMSDiscreteScheduler,
+    FlaxPNDMScheduler,
+)
+from ...utils import PIL_INTERPOLATION, logging
+from ..pipeline_flax_utils import FlaxDiffusionPipeline
+from . import FlaxStableDiffusionXLPipelineOutput
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+# Set to True to use python for loop instead of jax.fori_loop for easier debugging
+DEBUG = False
+
+
+class FlaxStableDiffusionXLImg2ImgPipeline(FlaxDiffusionPipeline):
+    # ignore_for_config = ["dtype", "requires_aesthetics_score"]
+
+    def __init__(
+        self,
+        text_encoder_2: FlaxCLIPTextModel,
+        vae: FlaxAutoencoderKL,
+        tokenizer_2: CLIPTokenizer,
+        unet: FlaxUNet2DConditionModel,
+        scheduler: Union[
+            FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler
+        ],
+        requires_aesthetics_score: bool = False,
+        # force_zeros_for_empty_prompt: bool = True,
+        dtype: jnp.dtype = jnp.float32,
+    ):
+        super().__init__()
+        self.dtype = dtype
+
+        # tokenizer and text_encoder are None in the refiner, so they are not registered
+        self.register_modules(
+            vae=vae,
+            text_encoder_2=text_encoder_2,
+            tokenizer_2=tokenizer_2,
+            unet=unet,
+            scheduler=scheduler,
+        )
+        self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+
+    def prepare_text_inputs(self, prompt: Union[str, List[str]]):
+        if not isinstance(prompt, (str, list)):
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+        input_ids = self.tokenizer_2(
+            prompt,
+            padding="max_length",
+            max_length=self.tokenizer_2.model_max_length,
+            truncation=True,
+            return_tensors="np",
+        ).input_ids
+
+        # Introduce an axis for consistency with FlaxStableDiffusionXLPipeline
+        return input_ids[:, jnp.newaxis, :]
+
+    def prepare_inputs(self, prompt: Union[str, List[str]], image: Union[Image.Image, List[Image.Image]]):
+        if not isinstance(image, (Image.Image, list)):
+            raise ValueError(f"`image` has to be of type `PIL.Image.Image` or `list` but is {type(image)}")
+
+        if isinstance(image, Image.Image):
+            image = [image]
+
+        processed_images = jnp.concatenate([preprocess(img, jnp.float32) for img in image])
+        text_input_ids = self.prepare_text_inputs(prompt)
+        return text_input_ids, processed_images
+
+    def __call__(
+        self,
+        prompt_ids: jax.Array,
+        image: jnp.array,
+        params: Union[Dict, FrozenDict],
+        prng_seed: jax.random.KeyArray,
+        strength: float = 0.3,
+        num_inference_steps: int = 50,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        guidance_scale: Union[float, jax.Array] = 7.5,
+        noise: jnp.array = None,
+        neg_prompt_ids: jnp.array = None,
+        aesthetic_score: float = 6.0,
+        negative_aesthetic_score: float = 2.5,
+        return_dict: bool = True,
+        jit: bool = False,
+    ):
+        # 0. Default height and width to unet
+        height = height or self.unet.config.sample_size * self.vae_scale_factor
+        width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+        if isinstance(guidance_scale, float) and jit:
+            # Convert to a tensor so each device gets a copy.
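+            # (`_p_generate` maps over the leading axis of its sharded arguments, so each
+            # device then sees a `(1,)` slice that broadcasts against its noise prediction.)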
+            guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0])
+            guidance_scale = guidance_scale[:, None]
+
+        start_timestep = self.get_timestep_start(num_inference_steps, strength)
+
+        if jit:
+            images = _p_generate(
+                self,
+                prompt_ids,
+                image,
+                params,
+                prng_seed,
+                start_timestep,
+                num_inference_steps,
+                height,
+                width,
+                guidance_scale,
+                noise,
+                neg_prompt_ids,
+                aesthetic_score,
+                negative_aesthetic_score,
+            )
+        else:
+            images = self._generate(
+                prompt_ids,
+                image,
+                params,
+                prng_seed,
+                start_timestep,
+                num_inference_steps,
+                height,
+                width,
+                guidance_scale,
+                noise,
+                neg_prompt_ids,
+                aesthetic_score,
+                negative_aesthetic_score,
+            )
+
+        if not return_dict:
+            return (images,)
+
+        return FlaxStableDiffusionXLPipelineOutput(images=images)
+
+    def get_embeddings(self, prompt_ids: jnp.array, params):
+        # bs, encoder_input, seq_length
+        te_inputs = prompt_ids[:, 0, :]
+
+        prompt_embeds_2_out = self.text_encoder_2(
+            te_inputs, params=params["text_encoder_2"], output_hidden_states=True
+        )
+        text_embeds = prompt_embeds_2_out["text_embeds"]
+        prompt_embeds = prompt_embeds_2_out["hidden_states"][-2]
+        return prompt_embeds, text_embeds
+
+    def _get_add_time_ids(
+        self, original_size, crops_coords_top_left, target_size, aesthetic_score, negative_aesthetic_score, bs, dtype
+    ):
+        if self.config.requires_aesthetics_score:
+            add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
+            add_neg_time_ids = list(original_size + crops_coords_top_left + (negative_aesthetic_score,))
+        else:
+            add_time_ids = list(original_size + crops_coords_top_left + target_size)
+            add_neg_time_ids = list(original_size + crops_coords_top_left + target_size)
+
+        # passed_add_embed_dim = (
+        #     self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
+        # )
+        # TODO: verify (params["unet"]["add_embedding"]["linear_1"]["kernel"].shape[0] ?)
+        # expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
+
+        # if (
+        #     expected_add_embed_dim > passed_add_embed_dim
+        #     and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim
+        # ):
+        #     raise ValueError(
+        #         f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model."
+        #     )
+        # elif (
+        #     expected_add_embed_dim < passed_add_embed_dim
+        #     and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim
+        # ):
+        #     raise ValueError(
+        #         f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model."
+        #     )
+        # elif expected_add_embed_dim != passed_add_embed_dim:
+        #     raise ValueError(
+        #         f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
+ # ) + + add_time_ids = jnp.array([add_time_ids] * bs, dtype=dtype) + add_neg_time_ids = jnp.array([add_neg_time_ids] * bs, dtype=dtype) + + return add_time_ids, add_neg_time_ids + + def get_timestep_start(self, num_inference_steps, strength): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + + return t_start + + def _generate( + self, + prompt_ids: jnp.array, + image: jnp.array, + params: Union[Dict, FrozenDict], + prng_seed: jax.random.KeyArray, + start_timestep: int, + num_inference_steps: int, + height: int, + width: int, + guidance_scale: float, + noise: Optional[jnp.array] = None, + neg_prompt_ids: Optional[jnp.array] = None, + aesthetic_score: float = 6.0, + negative_aesthetic_score: float = 2.5, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + # 1. Encode input prompt + prompt_embeds, pooled_embeds = self.get_embeddings(prompt_ids, params) + + # 2. Get unconditional embeddings + batch_size = prompt_embeds.shape[0] + if neg_prompt_ids is None: + neg_prompt_ids = self.prepare_text_inputs([""] * batch_size) + + neg_prompt_embeds, negative_pooled_embeds = self.get_embeddings(neg_prompt_ids, params) + + add_time_ids, add_neg_time_ids = self._get_add_time_ids( + (height, width), + (0, 0), + (height, width), + aesthetic_score, + negative_aesthetic_score, + prompt_embeds.shape[0], + dtype=prompt_embeds.dtype, + ) + + prompt_embeds = jnp.concatenate([neg_prompt_embeds, prompt_embeds], axis=0) + add_text_embeds = jnp.concatenate([negative_pooled_embeds, pooled_embeds], axis=0) + add_time_ids = jnp.concatenate([add_neg_time_ids, add_time_ids], axis=0) + + latents_shape = ( + batch_size, + self.unet.config.in_channels, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if noise is None: + noise = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32) + else: + if noise.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {noise.shape}, expected {latents_shape}") + + if image.shape[1] == 4: + # Skip encoding if using latents as input + init_latents = image + else: + # Create init_latents + init_latent_dist = self.vae.apply({"params": params["vae"]}, image, method=self.vae.encode).latent_dist + init_latents = init_latent_dist.sample(key=prng_seed).transpose((0, 3, 1, 2)) + init_latents = self.vae.config.scaling_factor * init_latents + + scheduler_state = self.scheduler.set_timesteps( + params["scheduler"], num_inference_steps=num_inference_steps, shape=latents_shape + ) + + latent_timestep = scheduler_state.timesteps[start_timestep : start_timestep + 1].repeat(batch_size) + latents = self.scheduler.add_noise(params["scheduler"], init_latents, noise, latent_timestep) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * params["scheduler"].init_noise_sigma + + # Ensure model output will be `float32` before going into the scheduler + guidance_scale = jnp.array([guidance_scale], dtype=jnp.float32) + + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + # 6. Denoising loop + def loop_body(step, args): + latents, scheduler_state = args + # For classifier free guidance, we need to do two forward passes. 
+ # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + latents_input = jnp.concatenate([latents] * 2) + + t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + timestep = jnp.broadcast_to(t, latents_input.shape[0]) + + latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t) + + # predict the noise residual + noise_pred = self.unet.apply( + {"params": params["unet"]}, + jnp.array(latents_input), + jnp.array(timestep, dtype=jnp.int32), + encoder_hidden_states=prompt_embeds, + added_cond_kwargs=added_cond_kwargs, + ).sample + # perform guidance + noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0) + noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() + return latents, scheduler_state + + if DEBUG: + # run with python for loop + for i in range(start_timestep, num_inference_steps): + latents, scheduler_state = loop_body(i, (latents, scheduler_state)) + else: + latents, _ = jax.lax.fori_loop(start_timestep, num_inference_steps, loop_body, (latents, scheduler_state)) + + # 7. Decode latents + # scale and decode the image latents with vae + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample + + image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1) + return image + + +# Static argnums are pipe, start_timestep, num_inference_steps, height, width, aesthetic_score, negative_aesthetic_score. +# A change would trigger recompilation. +# Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`). 
+@partial(
+    jax.pmap,
+    in_axes=(None, 0, 0, 0, 0, None, None, None, None, 0, 0, 0, None, None),
+    static_broadcasted_argnums=(0, 5, 6, 7, 8, 12, 13),
+)
+def _p_generate(
+    pipe,
+    prompt_ids,
+    image,
+    params,
+    prng_seed,
+    start_timestep,
+    num_inference_steps,
+    height,
+    width,
+    guidance_scale,
+    noise,
+    neg_prompt_ids,
+    aesthetic_score,
+    negative_aesthetic_score,
+):
+    return pipe._generate(
+        prompt_ids,
+        image,
+        params,
+        prng_seed,
+        start_timestep,
+        num_inference_steps,
+        height,
+        width,
+        guidance_scale,
+        noise,
+        neg_prompt_ids,
+        aesthetic_score,
+        negative_aesthetic_score,
+    )
+
+
+def preprocess(image, dtype):
+    w, h = image.size
+    w, h = (x - x % 32 for x in (w, h))  # resize to integer multiple of 32
+    image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
+    image = jnp.array(image).astype(dtype) / 255.0
+    image = image[None].transpose(0, 3, 1, 2)
+    return 2.0 * image - 1.0

diff --git a/src/diffusers/utils/dummy_flax_and_transformers_objects.py b/src/diffusers/utils/dummy_flax_and_transformers_objects.py
index 5e65e5349bb0..8b37a597c830 100644
--- a/src/diffusers/utils/dummy_flax_and_transformers_objects.py
+++ b/src/diffusers/utils/dummy_flax_and_transformers_objects.py
@@ -62,6 +62,21 @@ def from_pretrained(cls, *args, **kwargs):
         requires_backends(cls, ["flax", "transformers"])


+class FlaxStableDiffusionXLImg2ImgPipeline(metaclass=DummyObject):
+    _backends = ["flax", "transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["flax", "transformers"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["flax", "transformers"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["flax", "transformers"])
+
+
 class FlaxStableDiffusionXLPipeline(metaclass=DummyObject):
     _backends = ["flax", "transformers"]