diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index cc82cb09f969..25db36bb1c7a 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -368,6 +368,7 @@
             "FlaxDDIMScheduler",
             "FlaxDDPMScheduler",
             "FlaxDPMSolverMultistepScheduler",
+            "FlaxEulerDiscreteScheduler",
             "FlaxKarrasVeScheduler",
             "FlaxLMSDiscreteScheduler",
             "FlaxPNDMScheduler",
@@ -395,6 +396,7 @@
             "FlaxStableDiffusionImg2ImgPipeline",
             "FlaxStableDiffusionInpaintPipeline",
             "FlaxStableDiffusionPipeline",
+            "FlaxStableDiffusionXLPipeline",
         ]
     )
@@ -673,6 +675,7 @@
         FlaxDDIMScheduler,
         FlaxDDPMScheduler,
         FlaxDPMSolverMultistepScheduler,
+        FlaxEulerDiscreteScheduler,
        FlaxKarrasVeScheduler,
        FlaxLMSDiscreteScheduler,
        FlaxPNDMScheduler,
@@ -691,6 +694,7 @@
        FlaxStableDiffusionImg2ImgPipeline,
        FlaxStableDiffusionInpaintPipeline,
        FlaxStableDiffusionPipeline,
+        FlaxStableDiffusionXLPipeline,
    )

    try:
diff --git a/src/diffusers/models/modeling_flax_pytorch_utils.py b/src/diffusers/models/modeling_flax_pytorch_utils.py
index f9de83f87dab..4768e82dec4a 100644
--- a/src/diffusers/models/modeling_flax_pytorch_utils.py
+++ b/src/diffusers/models/modeling_flax_pytorch_utils.py
@@ -42,9 +42,25 @@ def rename_key(key):
 # and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py
 def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict):
     """Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary"""
-    # conv norm or layer norm
     renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
+
+    # rename attention layers
+    if len(pt_tuple_key) > 1:
+        for rename_from, rename_to in (
+            ("to_out_0", "proj_attn"),
+            ("to_k", "key"),
+            ("to_v", "value"),
+            ("to_q", "query"),
+        ):
+            if pt_tuple_key[-2] == rename_from:
+                weight_name = pt_tuple_key[-1]
+                weight_name = "kernel" if weight_name == "weight" else weight_name
+                renamed_pt_tuple_key = pt_tuple_key[:-2] + (rename_to, weight_name)
+                if renamed_pt_tuple_key in random_flax_state_dict:
+                    assert random_flax_state_dict[renamed_pt_tuple_key].shape == pt_tensor.T.shape
+                    return renamed_pt_tuple_key, pt_tensor.T
+
     if (
         any("norm" in str_ for str_ in pt_tuple_key)
         and (pt_tuple_key[-1] == "bias")
diff --git a/src/diffusers/models/modeling_flax_utils.py b/src/diffusers/models/modeling_flax_utils.py
index 4e4cebebe291..97f7b43bc64e 100644
--- a/src/diffusers/models/modeling_flax_utils.py
+++ b/src/diffusers/models/modeling_flax_utils.py
@@ -303,23 +303,26 @@ def from_pretrained(
             "framework": "flax",
         }

-        # Load config if we don't provide a configuration
-        config_path = config if config is not None else pretrained_model_name_or_path
-        model, model_kwargs = cls.from_config(
-            config_path,
-            cache_dir=cache_dir,
-            return_unused_kwargs=True,
-            force_download=force_download,
-            resume_download=resume_download,
-            proxies=proxies,
-            local_files_only=local_files_only,
-            use_auth_token=use_auth_token,
-            revision=revision,
-            subfolder=subfolder,
-            # model args
-            dtype=dtype,
-            **kwargs,
-        )
+        # Make sure `unused_kwargs` is defined even when an explicit `config` is passed
+        unused_kwargs = kwargs
+
+        # Load config if we don't provide one
+        if config is None:
+            config, unused_kwargs = cls.load_config(
+                pretrained_model_name_or_path,
+                cache_dir=cache_dir,
+                return_unused_kwargs=True,
+                force_download=force_download,
+                resume_download=resume_download,
+                proxies=proxies,
+                local_files_only=local_files_only,
+                use_auth_token=use_auth_token,
+                revision=revision,
+                subfolder=subfolder,
+                **kwargs,
+            )
+
+        model, model_kwargs = cls.from_config(config, dtype=dtype, return_unused_kwargs=True, **unused_kwargs)

         # Load model
         pretrained_path_with_subfolder = (
diff --git a/src/diffusers/models/unet_2d_blocks_flax.py b/src/diffusers/models/unet_2d_blocks_flax.py
index 0d1447570dda..8c24b9f264b0 100644
--- a/src/diffusers/models/unet_2d_blocks_flax.py
+++ b/src/diffusers/models/unet_2d_blocks_flax.py
@@ -52,6 +52,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
     only_cross_attention: bool = False
     use_memory_efficient_attention: bool = False
     dtype: jnp.dtype = jnp.float32
+    transformer_layers_per_block: int = 1

     def setup(self):
         resnets = []
@@ -72,7 +73,7 @@ def setup(self):
                 in_channels=self.out_channels,
                 n_heads=self.num_attention_heads,
                 d_head=self.out_channels // self.num_attention_heads,
-                depth=1,
+                depth=self.transformer_layers_per_block,
                 use_linear_projection=self.use_linear_projection,
                 only_cross_attention=self.only_cross_attention,
                 use_memory_efficient_attention=self.use_memory_efficient_attention,
@@ -192,6 +193,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
     only_cross_attention: bool = False
     use_memory_efficient_attention: bool = False
     dtype: jnp.dtype = jnp.float32
+    transformer_layers_per_block: int = 1

     def setup(self):
         resnets = []
@@ -213,7 +215,7 @@ def setup(self):
                 in_channels=self.out_channels,
                 n_heads=self.num_attention_heads,
                 d_head=self.out_channels // self.num_attention_heads,
-                depth=1,
+                depth=self.transformer_layers_per_block,
                 use_linear_projection=self.use_linear_projection,
                 only_cross_attention=self.only_cross_attention,
                 use_memory_efficient_attention=self.use_memory_efficient_attention,
@@ -331,6 +333,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
     use_linear_projection: bool = False
     use_memory_efficient_attention: bool = False
     dtype: jnp.dtype = jnp.float32
+    transformer_layers_per_block: int = 1

     def setup(self):
         # there is always at least one resnet
@@ -350,7 +353,7 @@ def setup(self):
                 in_channels=self.in_channels,
                 n_heads=self.num_attention_heads,
                 d_head=self.in_channels // self.num_attention_heads,
-                depth=1,
+                depth=self.transformer_layers_per_block,
                 use_linear_projection=self.use_linear_projection,
                 use_memory_efficient_attention=self.use_memory_efficient_attention,
                 dtype=self.dtype,
diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py
index 60b66a7fdf04..385f0a42c598 100644
--- a/src/diffusers/models/unet_2d_condition.py
+++ b/src/diffusers/models/unet_2d_condition.py
@@ -883,7 +883,6 @@ def forward(
             time_ids = added_cond_kwargs.get("time_ids")
             time_embeds = self.add_time_proj(time_ids.flatten())
             time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
-
             add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
             add_embeds = add_embeds.to(emb.dtype)
             aug_emb = self.add_embedding(add_embeds)
diff --git a/src/diffusers/models/unet_2d_condition_flax.py b/src/diffusers/models/unet_2d_condition_flax.py
index de39bc75d2e3..a511da1be318 100644
--- a/src/diffusers/models/unet_2d_condition_flax.py
+++ b/src/diffusers/models/unet_2d_condition_flax.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Optional, Tuple, Union
+from typing import Dict, Optional, Tuple, Union

 import flax
 import flax.linen as nn
@@ -116,6 +116,11 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
     flip_sin_to_cos: bool = True
     freq_shift: int = 0
     use_memory_efficient_attention: bool = False
+    transformer_layers_per_block: Union[int, Tuple[int]] = 1
+    addition_embed_type: Optional[str] = None
+    addition_time_embed_dim: Optional[int] = None
+    addition_embed_type_num_heads: int = 64
+    projection_class_embeddings_input_dim: Optional[int] = None

     def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
         # init input tensors
@@ -127,7 +132,17 @@ def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
         params_rng, dropout_rng = jax.random.split(rng)
         rngs = {"params": params_rng, "dropout": dropout_rng}

-        return self.init(rngs, sample, timesteps, encoder_hidden_states)["params"]
+        added_cond_kwargs = None
+        if self.addition_embed_type == "text_time":
+            # TODO: how to get this from the config? It's no longer cross_attention_dim
+            text_embeds_dim = 1280
+            time_ids_channels = self.projection_class_embeddings_input_dim - text_embeds_dim
+            time_ids_dims = time_ids_channels // self.addition_time_embed_dim
+            added_cond_kwargs = {
+                "text_embeds": jnp.zeros((1, text_embeds_dim), dtype=jnp.float32),
+                "time_ids": jnp.zeros((1, time_ids_dims), dtype=jnp.float32),
+            }
+        return self.init(rngs, sample, timesteps, encoder_hidden_states, added_cond_kwargs)["params"]

     def setup(self):
         block_out_channels = self.block_out_channels
@@ -168,6 +183,24 @@ def setup(self):
         if isinstance(num_attention_heads, int):
             num_attention_heads = (num_attention_heads,) * len(self.down_block_types)

+        # transformer layers per block
+        transformer_layers_per_block = self.transformer_layers_per_block
+        if isinstance(transformer_layers_per_block, int):
+            transformer_layers_per_block = [transformer_layers_per_block] * len(self.down_block_types)
+
+        # addition embed types
+        if self.addition_embed_type is None:
+            self.add_embedding = None
+        elif self.addition_embed_type == "text_time":
+            if self.addition_time_embed_dim is None:
+                raise ValueError(
+                    f"addition_embed_type {self.addition_embed_type} requires `addition_time_embed_dim` to not be None"
+                )
+            self.add_time_proj = FlaxTimesteps(self.addition_time_embed_dim, self.flip_sin_to_cos, self.freq_shift)
+            self.add_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype)
+        else:
+            raise ValueError(f"addition_embed_type: {self.addition_embed_type} must be None or `text_time`.")
+
         # down
         down_blocks = []
         output_channel = block_out_channels[0]
@@ -182,6 +215,7 @@ def setup(self):
                 out_channels=output_channel,
                 dropout=self.dropout,
                 num_layers=self.layers_per_block,
+                transformer_layers_per_block=transformer_layers_per_block[i],
                 num_attention_heads=num_attention_heads[i],
                 add_downsample=not is_final_block,
                 use_linear_projection=self.use_linear_projection,
@@ -207,6 +241,7 @@ def setup(self):
             in_channels=block_out_channels[-1],
             dropout=self.dropout,
             num_attention_heads=num_attention_heads[-1],
+            transformer_layers_per_block=transformer_layers_per_block[-1],
             use_linear_projection=self.use_linear_projection,
             use_memory_efficient_attention=self.use_memory_efficient_attention,
             dtype=self.dtype,
@@ -218,6 +253,7 @@ def setup(self):
         reversed_num_attention_heads = list(reversed(num_attention_heads))
         only_cross_attention = list(reversed(only_cross_attention))
         output_channel = reversed_block_out_channels[0]
+        reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
         for i, up_block_type in enumerate(self.up_block_types):
             prev_output_channel = output_channel
             output_channel = reversed_block_out_channels[i]
@@ -231,6 +267,7 @@ def setup(self):
                 out_channels=output_channel,
                 prev_output_channel=prev_output_channel,
                 num_layers=self.layers_per_block + 1,
+                transformer_layers_per_block=reversed_transformer_layers_per_block[i],
                 num_attention_heads=reversed_num_attention_heads[i],
                 add_upsample=not is_final_block,
                 dropout=self.dropout,
@@ -269,6 +306,7 @@ def __call__(
         sample,
         timesteps,
         encoder_hidden_states,
+        added_cond_kwargs: Optional[Union[Dict, FrozenDict]] = None,
         down_block_additional_residuals=None,
         mid_block_additional_residual=None,
         return_dict: bool = True,
@@ -300,6 +338,31 @@ def __call__(
         t_emb = self.time_proj(timesteps)
         t_emb = self.time_embedding(t_emb)

+        # additional embeddings
+        aug_emb = None
+        if self.addition_embed_type == "text_time":
+            if added_cond_kwargs is None:
+                raise ValueError(
+                    f"Need to provide argument `added_cond_kwargs` for {self.__class__} when using `addition_embed_type={self.addition_embed_type}`"
+                )
+            text_embeds = added_cond_kwargs.get("text_embeds")
+            if text_embeds is None:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
+                )
+            time_ids = added_cond_kwargs.get("time_ids")
+            if time_ids is None:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
+                )
+            # compute time embeds
+            time_embeds = self.add_time_proj(jnp.ravel(time_ids))  # (1, 6) => (6,) => (6, 256)
+            time_embeds = jnp.reshape(time_embeds, (text_embeds.shape[0], -1))
+            add_embeds = jnp.concatenate([text_embeds, time_embeds], axis=-1)
+            aug_emb = self.add_embedding(add_embeds)
+
+        t_emb = t_emb + aug_emb if aug_emb is not None else t_emb
+
         # 2.
pre-process sample = jnp.transpose(sample, (0, 2, 3, 1)) sample = self.conv_in(sample) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 94e849a68788..8a55cca5ab3f 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -1,460 +1,471 @@ -from typing import TYPE_CHECKING - -from ..utils import ( - OptionalDependencyNotAvailable, - _LazyModule, - get_objects_from_module, - is_flax_available, - is_k_diffusion_available, - is_librosa_available, - is_note_seq_available, - is_onnx_available, - is_torch_available, - is_transformers_available, -) - - -# These modules contain pipelines from multiple libraries/frameworks -_dummy_objects = {} -_import_structure = {"stable_diffusion": [], "latent_diffusion": [], "controlnet": []} - -try: - if not is_torch_available(): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from ..utils import dummy_pt_objects # noqa F403 - - _dummy_objects.update(get_objects_from_module(dummy_pt_objects)) -else: - _import_structure["auto_pipeline"] = [ - "AutoPipelineForImage2Image", - "AutoPipelineForInpainting", - "AutoPipelineForText2Image", - ] - _import_structure["consistency_models"] = ["ConsistencyModelPipeline"] - _import_structure["dance_diffusion"] = ["DanceDiffusionPipeline"] - _import_structure["ddim"] = ["DDIMPipeline"] - _import_structure["ddpm"] = ["DDPMPipeline"] - _import_structure["dit"] = ["DiTPipeline"] - _import_structure["latent_diffusion"].extend(["LDMSuperResolutionPipeline"]) - _import_structure["latent_diffusion_uncond"] = ["LDMPipeline"] - _import_structure["pipeline_utils"] = ["AudioPipelineOutput", "DiffusionPipeline", "ImagePipelineOutput"] - _import_structure["pndm"] = ["PNDMPipeline"] - _import_structure["repaint"] = ["RePaintPipeline"] - _import_structure["score_sde_ve"] = ["ScoreSdeVePipeline"] - _import_structure["stochastic_karras_ve"] = ["KarrasVePipeline"] -try: - if not (is_torch_available() and is_librosa_available()): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from ..utils import dummy_torch_and_librosa_objects # noqa F403 - - _dummy_objects.update(get_objects_from_module(dummy_torch_and_librosa_objects)) -else: - _import_structure["audio_diffusion"] = ["AudioDiffusionPipeline", "Mel"] -try: - if not (is_torch_available() and is_transformers_available()): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from ..utils import dummy_torch_and_transformers_objects # noqa F403 - - _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) -else: - _import_structure["alt_diffusion"] = ["AltDiffusionImg2ImgPipeline", "AltDiffusionPipeline"] - _import_structure["audioldm"] = ["AudioLDMPipeline"] - _import_structure["audioldm2"] = [ - "AudioLDM2Pipeline", - "AudioLDM2ProjectionModel", - "AudioLDM2UNet2DConditionModel", - ] - _import_structure["blip_diffusion"] = ["BlipDiffusionPipeline"] - _import_structure["controlnet"].extend( - [ - "BlipDiffusionControlNetPipeline", - "StableDiffusionControlNetImg2ImgPipeline", - "StableDiffusionControlNetInpaintPipeline", - "StableDiffusionControlNetPipeline", - "StableDiffusionXLControlNetImg2ImgPipeline", - "StableDiffusionXLControlNetInpaintPipeline", - "StableDiffusionXLControlNetPipeline", - ] - ) - _import_structure["deepfloyd_if"] = [ - "IFImg2ImgPipeline", - "IFImg2ImgSuperResolutionPipeline", - "IFInpaintingPipeline", - "IFInpaintingSuperResolutionPipeline", - "IFPipeline", - 
"IFSuperResolutionPipeline", - ] - _import_structure["kandinsky"] = [ - "KandinskyCombinedPipeline", - "KandinskyImg2ImgCombinedPipeline", - "KandinskyImg2ImgPipeline", - "KandinskyInpaintCombinedPipeline", - "KandinskyInpaintPipeline", - "KandinskyPipeline", - "KandinskyPriorPipeline", - ] - _import_structure["kandinsky2_2"] = [ - "KandinskyV22CombinedPipeline", - "KandinskyV22ControlnetImg2ImgPipeline", - "KandinskyV22ControlnetPipeline", - "KandinskyV22Img2ImgCombinedPipeline", - "KandinskyV22Img2ImgPipeline", - "KandinskyV22InpaintCombinedPipeline", - "KandinskyV22InpaintPipeline", - "KandinskyV22Pipeline", - "KandinskyV22PriorEmb2EmbPipeline", - "KandinskyV22PriorPipeline", - ] - _import_structure["latent_diffusion"].extend(["LDMTextToImagePipeline"]) - _import_structure["musicldm"] = ["MusicLDMPipeline"] - _import_structure["paint_by_example"] = ["PaintByExamplePipeline"] - _import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"] - _import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"] - _import_structure["stable_diffusion"].extend( - [ - "CLIPImageProjection", - "CycleDiffusionPipeline", - "StableDiffusionAttendAndExcitePipeline", - "StableDiffusionDepth2ImgPipeline", - "StableDiffusionDiffEditPipeline", - "StableDiffusionGLIGENPipeline", - "StableDiffusionGLIGENPipeline", - "StableDiffusionGLIGENTextImagePipeline", - "StableDiffusionImageVariationPipeline", - "StableDiffusionImg2ImgPipeline", - "StableDiffusionInpaintPipeline", - "StableDiffusionInpaintPipelineLegacy", - "StableDiffusionInstructPix2PixPipeline", - "StableDiffusionLatentUpscalePipeline", - "StableDiffusionLDM3DPipeline", - "StableDiffusionModelEditingPipeline", - "StableDiffusionPanoramaPipeline", - "StableDiffusionParadigmsPipeline", - "StableDiffusionPipeline", - "StableDiffusionPix2PixZeroPipeline", - "StableDiffusionSAGPipeline", - "StableDiffusionUpscalePipeline", - "StableUnCLIPImg2ImgPipeline", - "StableUnCLIPPipeline", - ] - ) - _import_structure["stable_diffusion_safe"] = ["StableDiffusionPipelineSafe"] - _import_structure["stable_diffusion_xl"] = [ - "StableDiffusionXLImg2ImgPipeline", - "StableDiffusionXLInpaintPipeline", - "StableDiffusionXLInstructPix2PixPipeline", - "StableDiffusionXLPipeline", - ] - _import_structure["t2i_adapter"] = ["StableDiffusionAdapterPipeline", "StableDiffusionXLAdapterPipeline"] - _import_structure["text_to_video_synthesis"] = [ - "TextToVideoSDPipeline", - "TextToVideoZeroPipeline", - "VideoToVideoSDPipeline", - ] - _import_structure["unclip"] = ["UnCLIPImageVariationPipeline", "UnCLIPPipeline"] - _import_structure["unidiffuser"] = [ - "ImageTextPipelineOutput", - "UniDiffuserModel", - "UniDiffuserPipeline", - "UniDiffuserTextDecoder", - ] - _import_structure["versatile_diffusion"] = [ - "VersatileDiffusionDualGuidedPipeline", - "VersatileDiffusionImageVariationPipeline", - "VersatileDiffusionPipeline", - "VersatileDiffusionTextToImagePipeline", - ] - _import_structure["vq_diffusion"] = ["VQDiffusionPipeline"] - _import_structure["wuerstchen"] = [ - "WuerstchenCombinedPipeline", - "WuerstchenDecoderPipeline", - "WuerstchenPriorPipeline", - ] -try: - if not is_onnx_available(): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from ..utils import dummy_onnx_objects # noqa F403 - - _dummy_objects.update(get_objects_from_module(dummy_onnx_objects)) -else: - _import_structure["onnx_utils"] = ["OnnxRuntimeModel"] -try: - if not (is_torch_available() and is_transformers_available() and 
is_onnx_available()): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from ..utils import dummy_torch_and_transformers_and_onnx_objects # noqa F403 - - _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_onnx_objects)) -else: - _import_structure["stable_diffusion"].extend( - [ - "OnnxStableDiffusionImg2ImgPipeline", - "OnnxStableDiffusionInpaintPipeline", - "OnnxStableDiffusionInpaintPipelineLegacy", - "OnnxStableDiffusionPipeline", - "OnnxStableDiffusionUpscalePipeline", - "StableDiffusionOnnxPipeline", - ] - ) -try: - if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from ..utils import dummy_torch_and_transformers_and_k_diffusion_objects # noqa F403 - - _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_k_diffusion_objects)) -else: - _import_structure["stable_diffusion"].extend(["StableDiffusionKDiffusionPipeline"]) -try: - if not is_flax_available(): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from ..utils import dummy_flax_objects # noqa F403 - - _dummy_objects.update(get_objects_from_module(dummy_flax_objects)) -else: - _import_structure["pipeline_flax_utils"] = ["FlaxDiffusionPipeline"] -try: - if not (is_flax_available() and is_transformers_available()): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from ..utils import dummy_flax_and_transformers_objects # noqa F403 - - _dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects)) -else: - _import_structure["controlnet"].extend(["FlaxStableDiffusionControlNetPipeline"]) - _import_structure["stable_diffusion"].extend( - [ - "FlaxStableDiffusionImg2ImgPipeline", - "FlaxStableDiffusionInpaintPipeline", - "FlaxStableDiffusionPipeline", - ] - ) -try: - if not (is_transformers_available() and is_torch_available() and is_note_seq_available()): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from ..utils import dummy_transformers_and_torch_and_note_seq_objects # noqa F403 - - _dummy_objects.update(get_objects_from_module(dummy_transformers_and_torch_and_note_seq_objects)) -else: - _import_structure["spectrogram_diffusion"] = ["MidiProcessor", "SpectrogramDiffusionPipeline"] - -if TYPE_CHECKING: - try: - if not is_torch_available(): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from ..utils.dummy_pt_objects import * # noqa F403 - - else: - from .auto_pipeline import AutoPipelineForImage2Image, AutoPipelineForInpainting, AutoPipelineForText2Image - from .consistency_models import ConsistencyModelPipeline - from .dance_diffusion import DanceDiffusionPipeline - from .ddim import DDIMPipeline - from .ddpm import DDPMPipeline - from .dit import DiTPipeline - from .latent_diffusion import LDMSuperResolutionPipeline - from .latent_diffusion_uncond import LDMPipeline - from .pipeline_utils import AudioPipelineOutput, DiffusionPipeline, ImagePipelineOutput - from .pndm import PNDMPipeline - from .repaint import RePaintPipeline - from .score_sde_ve import ScoreSdeVePipeline - from .stochastic_karras_ve import KarrasVePipeline - - try: - if not (is_torch_available() and is_librosa_available()): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from ..utils.dummy_torch_and_librosa_objects import * - else: - from .audio_diffusion import 
AudioDiffusionPipeline, Mel - - try: - if not (is_torch_available() and is_transformers_available()): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from ..utils.dummy_torch_and_transformers_objects import * - else: - from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline - from .audioldm import AudioLDMPipeline - from .audioldm2 import AudioLDM2Pipeline, AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel - from .blip_diffusion import BlipDiffusionPipeline - from .controlnet import ( - BlipDiffusionControlNetPipeline, - StableDiffusionControlNetImg2ImgPipeline, - StableDiffusionControlNetInpaintPipeline, - StableDiffusionControlNetPipeline, - StableDiffusionXLControlNetImg2ImgPipeline, - StableDiffusionXLControlNetInpaintPipeline, - StableDiffusionXLControlNetPipeline, - ) - from .deepfloyd_if import ( - IFImg2ImgPipeline, - IFImg2ImgSuperResolutionPipeline, - IFInpaintingPipeline, - IFInpaintingSuperResolutionPipeline, - IFPipeline, - IFSuperResolutionPipeline, - ) - from .kandinsky import ( - KandinskyCombinedPipeline, - KandinskyImg2ImgCombinedPipeline, - KandinskyImg2ImgPipeline, - KandinskyInpaintCombinedPipeline, - KandinskyInpaintPipeline, - KandinskyPipeline, - KandinskyPriorPipeline, - ) - from .kandinsky2_2 import ( - KandinskyV22CombinedPipeline, - KandinskyV22ControlnetImg2ImgPipeline, - KandinskyV22ControlnetPipeline, - KandinskyV22Img2ImgCombinedPipeline, - KandinskyV22Img2ImgPipeline, - KandinskyV22InpaintCombinedPipeline, - KandinskyV22InpaintPipeline, - KandinskyV22Pipeline, - KandinskyV22PriorEmb2EmbPipeline, - KandinskyV22PriorPipeline, - ) - from .latent_diffusion import LDMTextToImagePipeline - from .musicldm import MusicLDMPipeline - from .paint_by_example import PaintByExamplePipeline - from .semantic_stable_diffusion import SemanticStableDiffusionPipeline - from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline - from .stable_diffusion import ( - CLIPImageProjection, - CycleDiffusionPipeline, - StableDiffusionAttendAndExcitePipeline, - StableDiffusionDepth2ImgPipeline, - StableDiffusionDiffEditPipeline, - StableDiffusionGLIGENPipeline, - StableDiffusionGLIGENTextImagePipeline, - StableDiffusionImageVariationPipeline, - StableDiffusionImg2ImgPipeline, - StableDiffusionInpaintPipeline, - StableDiffusionInpaintPipelineLegacy, - StableDiffusionInstructPix2PixPipeline, - StableDiffusionLatentUpscalePipeline, - StableDiffusionLDM3DPipeline, - StableDiffusionModelEditingPipeline, - StableDiffusionPanoramaPipeline, - StableDiffusionParadigmsPipeline, - StableDiffusionPipeline, - StableDiffusionPix2PixZeroPipeline, - StableDiffusionSAGPipeline, - StableDiffusionUpscalePipeline, - StableUnCLIPImg2ImgPipeline, - StableUnCLIPPipeline, - ) - from .stable_diffusion_safe import StableDiffusionPipelineSafe - from .stable_diffusion_xl import ( - StableDiffusionXLImg2ImgPipeline, - StableDiffusionXLInpaintPipeline, - StableDiffusionXLInstructPix2PixPipeline, - StableDiffusionXLPipeline, - ) - from .t2i_adapter import StableDiffusionAdapterPipeline, StableDiffusionXLAdapterPipeline - from .text_to_video_synthesis import ( - TextToVideoSDPipeline, - TextToVideoZeroPipeline, - VideoToVideoSDPipeline, - ) - from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline - from .unidiffuser import ( - ImageTextPipelineOutput, - UniDiffuserModel, - UniDiffuserPipeline, - UniDiffuserTextDecoder, - ) - from .versatile_diffusion import ( - VersatileDiffusionDualGuidedPipeline, - VersatileDiffusionImageVariationPipeline, - 
VersatileDiffusionPipeline, - VersatileDiffusionTextToImagePipeline, - ) - from .vq_diffusion import VQDiffusionPipeline - from .wuerstchen import ( - WuerstchenCombinedPipeline, - WuerstchenDecoderPipeline, - WuerstchenPriorPipeline, - ) - - try: - if not is_onnx_available(): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from ..utils.dummy_onnx_objects import * # noqa F403 - - else: - from .onnx_utils import OnnxRuntimeModel - - try: - if not (is_torch_available() and is_transformers_available() and is_onnx_available()): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from ..utils.dummy_torch_and_transformers_and_onnx_objects import * - else: - from .stable_diffusion import ( - OnnxStableDiffusionImg2ImgPipeline, - OnnxStableDiffusionInpaintPipeline, - OnnxStableDiffusionInpaintPipelineLegacy, - OnnxStableDiffusionPipeline, - OnnxStableDiffusionUpscalePipeline, - StableDiffusionOnnxPipeline, - ) - - try: - if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from ..utils.dummy_torch_and_transformers_and_k_diffusion_objects import * - else: - from .stable_diffusion import StableDiffusionKDiffusionPipeline - - try: - if not is_flax_available(): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from ..utils.dummy_flax_objects import * # noqa F403 - else: - from .pipeline_flax_utils import FlaxDiffusionPipeline - - try: - if not (is_flax_available() and is_transformers_available()): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from ..utils.dummy_flax_and_transformers_objects import * - else: - from .controlnet import FlaxStableDiffusionControlNetPipeline - from .stable_diffusion import ( - FlaxStableDiffusionImg2ImgPipeline, - FlaxStableDiffusionInpaintPipeline, - FlaxStableDiffusionPipeline, - ) - - try: - if not (is_transformers_available() and is_torch_available() and is_note_seq_available()): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from ..utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403 - - else: - from .spectrogram_diffusion import MidiProcessor, SpectrogramDiffusionPipeline - -else: - import sys - - sys.modules[__name__] = _LazyModule( - __name__, - globals()["__file__"], - _import_structure, - module_spec=__spec__, - ) - for name, value in _dummy_objects.items(): - setattr(sys.modules[__name__], name, value) +from typing import TYPE_CHECKING + +from ..utils import ( + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_flax_available, + is_k_diffusion_available, + is_librosa_available, + is_note_seq_available, + is_onnx_available, + is_torch_available, + is_transformers_available, +) + + +# These modules contain pipelines from multiple libraries/frameworks +_dummy_objects = {} +_import_structure = {"stable_diffusion": [], "stable_diffusion_xl": [], "latent_diffusion": [], "controlnet": []} + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils import dummy_pt_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_pt_objects)) +else: + _import_structure["auto_pipeline"] = [ + "AutoPipelineForImage2Image", + "AutoPipelineForInpainting", + "AutoPipelineForText2Image", + ] + _import_structure["consistency_models"] = 
["ConsistencyModelPipeline"] + _import_structure["dance_diffusion"] = ["DanceDiffusionPipeline"] + _import_structure["ddim"] = ["DDIMPipeline"] + _import_structure["ddpm"] = ["DDPMPipeline"] + _import_structure["dit"] = ["DiTPipeline"] + _import_structure["latent_diffusion"].extend(["LDMSuperResolutionPipeline"]) + _import_structure["latent_diffusion_uncond"] = ["LDMPipeline"] + _import_structure["pipeline_utils"] = ["AudioPipelineOutput", "DiffusionPipeline", "ImagePipelineOutput"] + _import_structure["pndm"] = ["PNDMPipeline"] + _import_structure["repaint"] = ["RePaintPipeline"] + _import_structure["score_sde_ve"] = ["ScoreSdeVePipeline"] + _import_structure["stochastic_karras_ve"] = ["KarrasVePipeline"] +try: + if not (is_torch_available() and is_librosa_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils import dummy_torch_and_librosa_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_librosa_objects)) +else: + _import_structure["audio_diffusion"] = ["AudioDiffusionPipeline", "Mel"] +try: + if not (is_torch_available() and is_transformers_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["alt_diffusion"] = ["AltDiffusionImg2ImgPipeline", "AltDiffusionPipeline"] + _import_structure["audioldm"] = ["AudioLDMPipeline"] + _import_structure["audioldm2"] = [ + "AudioLDM2Pipeline", + "AudioLDM2ProjectionModel", + "AudioLDM2UNet2DConditionModel", + ] + _import_structure["blip_diffusion"] = ["BlipDiffusionPipeline"] + _import_structure["controlnet"].extend( + [ + "BlipDiffusionControlNetPipeline", + "StableDiffusionControlNetImg2ImgPipeline", + "StableDiffusionControlNetInpaintPipeline", + "StableDiffusionControlNetPipeline", + "StableDiffusionXLControlNetImg2ImgPipeline", + "StableDiffusionXLControlNetInpaintPipeline", + "StableDiffusionXLControlNetPipeline", + ] + ) + _import_structure["deepfloyd_if"] = [ + "IFImg2ImgPipeline", + "IFImg2ImgSuperResolutionPipeline", + "IFInpaintingPipeline", + "IFInpaintingSuperResolutionPipeline", + "IFPipeline", + "IFSuperResolutionPipeline", + ] + _import_structure["kandinsky"] = [ + "KandinskyCombinedPipeline", + "KandinskyImg2ImgCombinedPipeline", + "KandinskyImg2ImgPipeline", + "KandinskyInpaintCombinedPipeline", + "KandinskyInpaintPipeline", + "KandinskyPipeline", + "KandinskyPriorPipeline", + ] + _import_structure["kandinsky2_2"] = [ + "KandinskyV22CombinedPipeline", + "KandinskyV22ControlnetImg2ImgPipeline", + "KandinskyV22ControlnetPipeline", + "KandinskyV22Img2ImgCombinedPipeline", + "KandinskyV22Img2ImgPipeline", + "KandinskyV22InpaintCombinedPipeline", + "KandinskyV22InpaintPipeline", + "KandinskyV22Pipeline", + "KandinskyV22PriorEmb2EmbPipeline", + "KandinskyV22PriorPipeline", + ] + _import_structure["latent_diffusion"].extend(["LDMTextToImagePipeline"]) + _import_structure["musicldm"] = ["MusicLDMPipeline"] + _import_structure["paint_by_example"] = ["PaintByExamplePipeline"] + _import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"] + _import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"] + _import_structure["stable_diffusion"].extend( + [ + "CLIPImageProjection", + "CycleDiffusionPipeline", + "StableDiffusionAttendAndExcitePipeline", + "StableDiffusionDepth2ImgPipeline", + 
"StableDiffusionDiffEditPipeline", + "StableDiffusionGLIGENPipeline", + "StableDiffusionGLIGENPipeline", + "StableDiffusionGLIGENTextImagePipeline", + "StableDiffusionImageVariationPipeline", + "StableDiffusionImg2ImgPipeline", + "StableDiffusionInpaintPipeline", + "StableDiffusionInpaintPipelineLegacy", + "StableDiffusionInstructPix2PixPipeline", + "StableDiffusionLatentUpscalePipeline", + "StableDiffusionLDM3DPipeline", + "StableDiffusionModelEditingPipeline", + "StableDiffusionPanoramaPipeline", + "StableDiffusionParadigmsPipeline", + "StableDiffusionPipeline", + "StableDiffusionPix2PixZeroPipeline", + "StableDiffusionSAGPipeline", + "StableDiffusionUpscalePipeline", + "StableUnCLIPImg2ImgPipeline", + "StableUnCLIPPipeline", + ] + ) + _import_structure["stable_diffusion_safe"] = ["StableDiffusionPipelineSafe"] + _import_structure["stable_diffusion_xl"].extend( + [ + "StableDiffusionXLImg2ImgPipeline", + "StableDiffusionXLInpaintPipeline", + "StableDiffusionXLInstructPix2PixPipeline", + "StableDiffusionXLPipeline", + ] + ) + _import_structure["t2i_adapter"] = ["StableDiffusionAdapterPipeline", "StableDiffusionXLAdapterPipeline"] + _import_structure["text_to_video_synthesis"] = [ + "TextToVideoSDPipeline", + "TextToVideoZeroPipeline", + "VideoToVideoSDPipeline", + ] + _import_structure["unclip"] = ["UnCLIPImageVariationPipeline", "UnCLIPPipeline"] + _import_structure["unidiffuser"] = [ + "ImageTextPipelineOutput", + "UniDiffuserModel", + "UniDiffuserPipeline", + "UniDiffuserTextDecoder", + ] + _import_structure["versatile_diffusion"] = [ + "VersatileDiffusionDualGuidedPipeline", + "VersatileDiffusionImageVariationPipeline", + "VersatileDiffusionPipeline", + "VersatileDiffusionTextToImagePipeline", + ] + _import_structure["vq_diffusion"] = ["VQDiffusionPipeline"] + _import_structure["wuerstchen"] = [ + "WuerstchenCombinedPipeline", + "WuerstchenDecoderPipeline", + "WuerstchenPriorPipeline", + ] +try: + if not is_onnx_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils import dummy_onnx_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_onnx_objects)) +else: + _import_structure["onnx_utils"] = ["OnnxRuntimeModel"] +try: + if not (is_torch_available() and is_transformers_available() and is_onnx_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils import dummy_torch_and_transformers_and_onnx_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_onnx_objects)) +else: + _import_structure["stable_diffusion"].extend( + [ + "OnnxStableDiffusionImg2ImgPipeline", + "OnnxStableDiffusionInpaintPipeline", + "OnnxStableDiffusionInpaintPipelineLegacy", + "OnnxStableDiffusionPipeline", + "OnnxStableDiffusionUpscalePipeline", + "StableDiffusionOnnxPipeline", + ] + ) + +try: + if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils import dummy_torch_and_transformers_and_k_diffusion_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_k_diffusion_objects)) +else: + _import_structure["stable_diffusion"].extend(["StableDiffusionKDiffusionPipeline"]) +try: + if not is_flax_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils import dummy_flax_objects # noqa F403 + + 
_dummy_objects.update(get_objects_from_module(dummy_flax_objects)) +else: + _import_structure["pipeline_flax_utils"] = ["FlaxDiffusionPipeline"] +try: + if not (is_flax_available() and is_transformers_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils import dummy_flax_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects)) +else: + _import_structure["controlnet"].extend(["FlaxStableDiffusionControlNetPipeline"]) + _import_structure["stable_diffusion"].extend( + [ + "FlaxStableDiffusionImg2ImgPipeline", + "FlaxStableDiffusionInpaintPipeline", + "FlaxStableDiffusionPipeline", + ] + ) + _import_structure["stable_diffusion_xl"].extend( + [ + "FlaxStableDiffusionXLPipeline", + ] + ) +try: + if not (is_transformers_available() and is_torch_available() and is_note_seq_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils import dummy_transformers_and_torch_and_note_seq_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_transformers_and_torch_and_note_seq_objects)) +else: + _import_structure["spectrogram_diffusion"] = ["MidiProcessor", "SpectrogramDiffusionPipeline"] + +if TYPE_CHECKING: + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ..utils.dummy_pt_objects import * # noqa F403 + + else: + from .auto_pipeline import AutoPipelineForImage2Image, AutoPipelineForInpainting, AutoPipelineForText2Image + from .consistency_models import ConsistencyModelPipeline + from .dance_diffusion import DanceDiffusionPipeline + from .ddim import DDIMPipeline + from .ddpm import DDPMPipeline + from .dit import DiTPipeline + from .latent_diffusion import LDMSuperResolutionPipeline + from .latent_diffusion_uncond import LDMPipeline + from .pipeline_utils import AudioPipelineOutput, DiffusionPipeline, ImagePipelineOutput + from .pndm import PNDMPipeline + from .repaint import RePaintPipeline + from .score_sde_ve import ScoreSdeVePipeline + from .stochastic_karras_ve import KarrasVePipeline + + try: + if not (is_torch_available() and is_librosa_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ..utils.dummy_torch_and_librosa_objects import * + else: + from .audio_diffusion import AudioDiffusionPipeline, Mel + + try: + if not (is_torch_available() and is_transformers_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ..utils.dummy_torch_and_transformers_objects import * + else: + from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline + from .audioldm import AudioLDMPipeline + from .audioldm2 import AudioLDM2Pipeline, AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel + from .blip_diffusion import BlipDiffusionPipeline + from .controlnet import ( + BlipDiffusionControlNetPipeline, + StableDiffusionControlNetImg2ImgPipeline, + StableDiffusionControlNetInpaintPipeline, + StableDiffusionControlNetPipeline, + StableDiffusionXLControlNetImg2ImgPipeline, + StableDiffusionXLControlNetInpaintPipeline, + StableDiffusionXLControlNetPipeline, + ) + from .deepfloyd_if import ( + IFImg2ImgPipeline, + IFImg2ImgSuperResolutionPipeline, + IFInpaintingPipeline, + IFInpaintingSuperResolutionPipeline, + IFPipeline, + IFSuperResolutionPipeline, + ) + from .kandinsky import ( + KandinskyCombinedPipeline, + 
KandinskyImg2ImgCombinedPipeline, + KandinskyImg2ImgPipeline, + KandinskyInpaintCombinedPipeline, + KandinskyInpaintPipeline, + KandinskyPipeline, + KandinskyPriorPipeline, + ) + from .kandinsky2_2 import ( + KandinskyV22CombinedPipeline, + KandinskyV22ControlnetImg2ImgPipeline, + KandinskyV22ControlnetPipeline, + KandinskyV22Img2ImgCombinedPipeline, + KandinskyV22Img2ImgPipeline, + KandinskyV22InpaintCombinedPipeline, + KandinskyV22InpaintPipeline, + KandinskyV22Pipeline, + KandinskyV22PriorEmb2EmbPipeline, + KandinskyV22PriorPipeline, + ) + from .latent_diffusion import LDMTextToImagePipeline + from .musicldm import MusicLDMPipeline + from .paint_by_example import PaintByExamplePipeline + from .semantic_stable_diffusion import SemanticStableDiffusionPipeline + from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline + from .stable_diffusion import ( + CLIPImageProjection, + CycleDiffusionPipeline, + StableDiffusionAttendAndExcitePipeline, + StableDiffusionDepth2ImgPipeline, + StableDiffusionDiffEditPipeline, + StableDiffusionGLIGENPipeline, + StableDiffusionGLIGENTextImagePipeline, + StableDiffusionImageVariationPipeline, + StableDiffusionImg2ImgPipeline, + StableDiffusionInpaintPipeline, + StableDiffusionInpaintPipelineLegacy, + StableDiffusionInstructPix2PixPipeline, + StableDiffusionLatentUpscalePipeline, + StableDiffusionLDM3DPipeline, + StableDiffusionModelEditingPipeline, + StableDiffusionPanoramaPipeline, + StableDiffusionParadigmsPipeline, + StableDiffusionPipeline, + StableDiffusionPix2PixZeroPipeline, + StableDiffusionSAGPipeline, + StableDiffusionUpscalePipeline, + StableUnCLIPImg2ImgPipeline, + StableUnCLIPPipeline, + ) + from .stable_diffusion_safe import StableDiffusionPipelineSafe + from .stable_diffusion_xl import ( + StableDiffusionXLImg2ImgPipeline, + StableDiffusionXLInpaintPipeline, + StableDiffusionXLInstructPix2PixPipeline, + StableDiffusionXLPipeline, + ) + from .t2i_adapter import StableDiffusionAdapterPipeline, StableDiffusionXLAdapterPipeline + from .text_to_video_synthesis import ( + TextToVideoSDPipeline, + TextToVideoZeroPipeline, + VideoToVideoSDPipeline, + ) + from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline + from .unidiffuser import ( + ImageTextPipelineOutput, + UniDiffuserModel, + UniDiffuserPipeline, + UniDiffuserTextDecoder, + ) + from .versatile_diffusion import ( + VersatileDiffusionDualGuidedPipeline, + VersatileDiffusionImageVariationPipeline, + VersatileDiffusionPipeline, + VersatileDiffusionTextToImagePipeline, + ) + from .vq_diffusion import VQDiffusionPipeline + from .wuerstchen import ( + WuerstchenCombinedPipeline, + WuerstchenDecoderPipeline, + WuerstchenPriorPipeline, + ) + + try: + if not is_onnx_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ..utils.dummy_onnx_objects import * # noqa F403 + + else: + from .onnx_utils import OnnxRuntimeModel + + try: + if not (is_torch_available() and is_transformers_available() and is_onnx_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ..utils.dummy_torch_and_transformers_and_onnx_objects import * + else: + from .stable_diffusion import ( + OnnxStableDiffusionImg2ImgPipeline, + OnnxStableDiffusionInpaintPipeline, + OnnxStableDiffusionInpaintPipelineLegacy, + OnnxStableDiffusionPipeline, + OnnxStableDiffusionUpscalePipeline, + StableDiffusionOnnxPipeline, + ) + + try: + if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()): + raise 
OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        from ..utils.dummy_torch_and_transformers_and_k_diffusion_objects import *
+    else:
+        from .stable_diffusion import StableDiffusionKDiffusionPipeline
+
+    try:
+        if not is_flax_available():
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        from ..utils.dummy_flax_objects import *  # noqa F403
+    else:
+        from .pipeline_flax_utils import FlaxDiffusionPipeline
+
+    try:
+        if not (is_flax_available() and is_transformers_available()):
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        from ..utils.dummy_flax_and_transformers_objects import *
+    else:
+        from .controlnet import FlaxStableDiffusionControlNetPipeline
+        from .stable_diffusion import (
+            FlaxStableDiffusionImg2ImgPipeline,
+            FlaxStableDiffusionInpaintPipeline,
+            FlaxStableDiffusionPipeline,
+        )
+        from .stable_diffusion_xl import (
+            FlaxStableDiffusionXLPipeline,
+        )
+
+    try:
+        if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        from ..utils.dummy_transformers_and_torch_and_note_seq_objects import *  # noqa F403
+
+    else:
+        from .spectrogram_diffusion import MidiProcessor, SpectrogramDiffusionPipeline
+
+else:
+    import sys
+
+    sys.modules[__name__] = _LazyModule(
+        __name__,
+        globals()["__file__"],
+        _import_structure,
+        module_spec=__spec__,
+    )
+    for name, value in _dummy_objects.items():
+        setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/pipeline_flax_utils.py b/src/diffusers/pipelines/pipeline_flax_utils.py
index 23a7af1e1b82..d35e8ea26be4 100644
--- a/src/diffusers/pipelines/pipeline_flax_utils.py
+++ b/src/diffusers/pipelines/pipeline_flax_utils.py
@@ -394,10 +394,29 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         # extract them here
         expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
         passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
+        passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}

-        init_dict, _, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
+        init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)

-        init_kwargs = {}
+        # define init kwargs
+        init_kwargs = {k: init_dict.pop(k) for k in optional_kwargs if k in init_dict}
+        init_kwargs = {**init_kwargs, **passed_pipe_kwargs}
+
+        # remove `null` components
+        def load_module(name, value):
+            if value[0] is None:
+                return False
+            if name in passed_class_obj and passed_class_obj[name] is None:
+                return False
+            return True
+
+        init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
+
+        # Throw nice warnings / errors for fast accelerate loading
+        if len(unused_kwargs) > 0:
+            logger.warning(
+                f"Keyword arguments {unused_kwargs} are not expected by {pipeline_class.__name__} and will be ignored."
+            )

         # inference_params
         params = {}
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/__init__.py b/src/diffusers/pipelines/stable_diffusion_xl/__init__.py
index 2c4bf44f8dec..90dfef809bca 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/__init__.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/__init__.py
@@ -4,14 +4,18 @@
     OptionalDependencyNotAvailable,
     _LazyModule,
     get_objects_from_module,
+    is_flax_available,
     is_torch_available,
     is_transformers_available,
 )


 _dummy_objects = {}
+_additional_imports = {}
 _import_structure = {"pipeline_output": ["StableDiffusionXLPipelineOutput"]}

+if is_transformers_available() and is_flax_available():
+    _import_structure["pipeline_output"].extend(["FlaxStableDiffusionXLPipelineOutput"])

 try:
     if not (is_transformers_available() and is_torch_available()):
         raise OptionalDependencyNotAvailable()
@@ -25,6 +29,12 @@
     _import_structure["pipeline_stable_diffusion_xl_inpaint"] = ["StableDiffusionXLInpaintPipeline"]
     _import_structure["pipeline_stable_diffusion_xl_instruct_pix2pix"] = ["StableDiffusionXLInstructPix2PixPipeline"]

+if is_transformers_available() and is_flax_available():
+    from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState
+
+    _additional_imports.update({"PNDMSchedulerState": PNDMSchedulerState})
+    _import_structure["pipeline_flax_stable_diffusion_xl"] = ["FlaxStableDiffusionXLPipeline"]
+

 if TYPE_CHECKING:
     try:
@@ -38,6 +48,17 @@
         from .pipeline_stable_diffusion_xl_inpaint import StableDiffusionXLInpaintPipeline
         from .pipeline_stable_diffusion_xl_instruct_pix2pix import StableDiffusionXLInstructPix2PixPipeline

+    try:
+        if not (is_transformers_available() and is_flax_available()):
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        from ...utils.dummy_flax_objects import *
+    else:
+        from .pipeline_flax_stable_diffusion_xl import (
+            FlaxStableDiffusionXLPipeline,
+        )
+        from .pipeline_output import FlaxStableDiffusionXLPipelineOutput
+
 else:
     import sys

@@ -50,3 +71,5 @@

     for name, value in _dummy_objects.items():
         setattr(sys.modules[__name__], name, value)
+    for name, value in _additional_imports.items():
+        setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py
new file mode 100644
index 000000000000..d561d67d4cc0
--- /dev/null
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py
@@ -0,0 +1,306 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import partial
+from typing import Dict, List, Optional, Union
+
+import jax
+import jax.numpy as jnp
+from flax.core.frozen_dict import FrozenDict
+from transformers import CLIPTokenizer, FlaxCLIPTextModel
+
+from diffusers.utils import logging
+
+from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel
+from ...schedulers import (
+    FlaxDDIMScheduler,
+    FlaxDPMSolverMultistepScheduler,
+    FlaxLMSDiscreteScheduler,
+    FlaxPNDMScheduler,
+)
+from ..pipeline_flax_utils import FlaxDiffusionPipeline
+from . import FlaxStableDiffusionXLPipelineOutput
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+# Set to True to use python for loop instead of jax.fori_loop for easier debugging
+DEBUG = False
+
+
+class FlaxStableDiffusionXLPipeline(FlaxDiffusionPipeline):
+    def __init__(
+        self,
+        text_encoder: FlaxCLIPTextModel,
+        text_encoder_2: FlaxCLIPTextModel,
+        vae: FlaxAutoencoderKL,
+        tokenizer: CLIPTokenizer,
+        tokenizer_2: CLIPTokenizer,
+        unet: FlaxUNet2DConditionModel,
+        scheduler: Union[
+            FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler
+        ],
+        dtype: jnp.dtype = jnp.float32,
+    ):
+        super().__init__()
+        self.dtype = dtype
+
+        self.register_modules(
+            vae=vae,
+            text_encoder=text_encoder,
+            text_encoder_2=text_encoder_2,
+            tokenizer=tokenizer,
+            tokenizer_2=tokenizer_2,
+            unet=unet,
+            scheduler=scheduler,
+        )
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+
+    def prepare_inputs(self, prompt: Union[str, List[str]]):
+        if not isinstance(prompt, (str, list)):
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+        # Assume we have the two encoders
+        inputs = []
+        for tokenizer in [self.tokenizer, self.tokenizer_2]:
+            text_inputs = tokenizer(
+                prompt,
+                padding="max_length",
+                max_length=self.tokenizer.model_max_length,
+                truncation=True,
+                return_tensors="np",
+            )
+            inputs.append(text_inputs.input_ids)
+        inputs = jnp.stack(inputs, axis=1)
+        return inputs
+
+    def __call__(
+        self,
+        prompt_ids: jax.Array,
+        params: Union[Dict, FrozenDict],
+        prng_seed: jax.random.KeyArray,
+        num_inference_steps: int = 50,
+        guidance_scale: Union[float, jax.Array] = 7.5,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        latents: jnp.array = None,
+        neg_prompt_ids: jnp.array = None,
+        return_dict: bool = True,
+        output_type: str = None,
+        jit: bool = False,
+    ):
+        # 0. Default height and width to unet
+        height = height or self.unet.config.sample_size * self.vae_scale_factor
+        width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+        if isinstance(guidance_scale, float) and jit:
+            # Convert to a tensor so each device gets a copy.
+            guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0])
+            guidance_scale = guidance_scale[:, None]
+
+        return_latents = output_type == "latent"
+
+        if jit:
+            images = _p_generate(
+                self,
+                prompt_ids,
+                params,
+                prng_seed,
+                num_inference_steps,
+                height,
+                width,
+                guidance_scale,
+                latents,
+                neg_prompt_ids,
+                return_latents,
+            )
+        else:
+            images = self._generate(
+                prompt_ids,
+                params,
+                prng_seed,
+                num_inference_steps,
+                height,
+                width,
+                guidance_scale,
+                latents,
+                neg_prompt_ids,
+                return_latents,
+            )
+
+        if not return_dict:
+            return (images,)
+
+        return FlaxStableDiffusionXLPipelineOutput(images=images)
+
+    def get_embeddings(self, prompt_ids: jnp.array, params):
+        # We assume we have the two encoders
+
+        # bs, encoder_input, seq_length
+        te_1_inputs = prompt_ids[:, 0, :]
+        te_2_inputs = prompt_ids[:, 1, :]
+
+        prompt_embeds = self.text_encoder(te_1_inputs, params=params["text_encoder"], output_hidden_states=True)
+        prompt_embeds = prompt_embeds["hidden_states"][-2]
+        prompt_embeds_2_out = self.text_encoder_2(
+            te_2_inputs, params=params["text_encoder_2"], output_hidden_states=True
+        )
+        prompt_embeds_2 = prompt_embeds_2_out["hidden_states"][-2]
+        text_embeds = prompt_embeds_2_out["text_embeds"]
+        prompt_embeds = jnp.concatenate([prompt_embeds, prompt_embeds_2], axis=-1)
+        return prompt_embeds, text_embeds
+
+    def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, bs, dtype):
+        add_time_ids = list(original_size + crops_coords_top_left + target_size)
+        add_time_ids = jnp.array([add_time_ids] * bs, dtype=dtype)
+        return add_time_ids
+
+    def _generate(
+        self,
+        prompt_ids: jnp.array,
+        params: Union[Dict, FrozenDict],
+        prng_seed: jax.random.KeyArray,
+        num_inference_steps: int,
+        height: int,
+        width: int,
+        guidance_scale: float,
+        latents: Optional[jnp.array] = None,
+        neg_prompt_ids: Optional[jnp.array] = None,
+        return_latents=False,
+    ):
+        if height % 8 != 0 or width % 8 != 0:
+            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+        # Encode input prompt
+        prompt_embeds, pooled_embeds = self.get_embeddings(prompt_ids, params)
+
+        # Get unconditional embeddings
+        batch_size = prompt_embeds.shape[0]
+        if neg_prompt_ids is None:
+            neg_prompt_ids = self.prepare_inputs([""] * batch_size)
+
+        neg_prompt_embeds, negative_pooled_embeds = self.get_embeddings(neg_prompt_ids, params)
+
+        add_time_ids = self._get_add_time_ids(
+            (height, width), (0, 0), (height, width), prompt_embeds.shape[0], dtype=prompt_embeds.dtype
+        )
+
+        prompt_embeds = jnp.concatenate([neg_prompt_embeds, prompt_embeds], axis=0)  # (2, 77, 2048)
+        add_text_embeds = jnp.concatenate([negative_pooled_embeds, pooled_embeds], axis=0)
+        add_time_ids = jnp.concatenate([add_time_ids, add_time_ids], axis=0)
+
+        # Ensure model output will be `float32` before going into the scheduler
+        guidance_scale = jnp.array([guidance_scale], dtype=jnp.float32)
+
+        # Create random latents
+        latents_shape = (
+            batch_size,
+            self.unet.config.in_channels,
+            height // self.vae_scale_factor,
+            width // self.vae_scale_factor,
+        )
+        if latents is None:
+            latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32)
+        else:
+            if latents.shape != latents_shape:
+                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
+
+        # scale the initial noise by the standard deviation required by the scheduler
+        latents = latents * params["scheduler"].init_noise_sigma
+
+        # Prepare scheduler state
+        scheduler_state = self.scheduler.set_timesteps(
+            params["scheduler"], num_inference_steps=num_inference_steps, shape=latents.shape
+        )
+
+        added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+
+        # Denoising loop
+        def loop_body(step, args):
+            latents, scheduler_state = args
+            # For classifier free guidance, we need to do two forward passes.
+            # Here we concatenate the unconditional and text embeddings into a single batch
+            # to avoid doing two forward passes
+            latents_input = jnp.concatenate([latents] * 2)
+
+            t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
+            timestep = jnp.broadcast_to(t, latents_input.shape[0])
+
+            latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t)
+
+            # predict the noise residual
+            noise_pred = self.unet.apply(
+                {"params": params["unet"]},
+                jnp.array(latents_input),
+                jnp.array(timestep, dtype=jnp.int32),
+                encoder_hidden_states=prompt_embeds,
+                added_cond_kwargs=added_cond_kwargs,
+            ).sample
+            # perform guidance
+            noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0)
+            noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
+
+            # compute the previous noisy sample x_t -> x_t-1
+            latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
+            return latents, scheduler_state
+
+        if DEBUG:
+            # run with python for loop
+            for i in range(num_inference_steps):
+                latents, scheduler_state = loop_body(i, (latents, scheduler_state))
+        else:
+            latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state))
+
+        if return_latents:
+            return latents
+
+        # Decode latents
+        latents = 1 / self.vae.config.scaling_factor * latents
+        image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample
+
+        image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
+        return image
+
+
+# Static argnums are pipe, num_inference_steps, height, width, return_latents. A change would trigger recompilation.
+# Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`).
+@partial(
+    jax.pmap,
+    in_axes=(None, 0, 0, 0, None, None, None, 0, 0, 0, None),
+    static_broadcasted_argnums=(0, 4, 5, 6, 10),
+)
+def _p_generate(
+    pipe,
+    prompt_ids,
+    params,
+    prng_seed,
+    num_inference_steps,
+    height,
+    width,
+    guidance_scale,
+    latents,
+    neg_prompt_ids,
+    return_latents,
+):
+    return pipe._generate(
+        prompt_ids,
+        params,
+        prng_seed,
+        num_inference_steps,
+        height,
+        width,
+        guidance_scale,
+        latents,
+        neg_prompt_ids,
+        return_latents,
+    )
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_output.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_output.py
index 0c9515da34ef..bff9071d203f 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_output.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_output.py
@@ -4,7 +4,11 @@
 import numpy as np
 import PIL

-from ...utils import BaseOutput
+from ...utils import (
+    BaseOutput,
+    is_flax_available,
+    is_transformers_available,
+)


 @dataclass
@@ -19,3 +23,19 @@ class StableDiffusionXLPipelineOutput(BaseOutput):
     """

     images: Union[List[PIL.Image.Image], np.ndarray]
+
+
+if is_transformers_available() and is_flax_available():
+    import flax
+
+    @flax.struct.dataclass
+    class FlaxStableDiffusionXLPipelineOutput(BaseOutput):
+        """
+        Output class for Flax Stable Diffusion XL pipelines.
+
+        Args:
+            images (`np.ndarray`):
+                Array of shape `(batch_size, height, width, num_channels)` with images from the diffusion pipeline.
+        """
+
+        images: np.ndarray
diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
index 703db3204e3d..99bf1d22ee91 100644
--- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
+++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
@@ -1094,7 +1094,6 @@ def forward(
             time_ids = added_cond_kwargs.get("time_ids")
             time_embeds = self.add_time_proj(time_ids.flatten())
             time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
-
             add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
             add_embeds = add_embeds.to(emb.dtype)
             aug_emb = self.add_embedding(add_embeds)
diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py
index bbd943e56cec..61f2a56d5dbe 100644
--- a/src/diffusers/schedulers/__init__.py
+++ b/src/diffusers/schedulers/__init__.py
@@ -76,6 +76,7 @@
     _import_structure["scheduling_ddim_flax"] = ["FlaxDDIMScheduler"]
     _import_structure["scheduling_ddpm_flax"] = ["FlaxDDPMScheduler"]
    _import_structure["scheduling_dpmsolver_multistep_flax"] = ["FlaxDPMSolverMultistepScheduler"]
+    _import_structure["scheduling_euler_discrete_flax"] = ["FlaxEulerDiscreteScheduler"]
     _import_structure["scheduling_karras_ve_flax"] = ["FlaxKarrasVeScheduler"]
     _import_structure["scheduling_lms_discrete_flax"] = ["FlaxLMSDiscreteScheduler"]
     _import_structure["scheduling_pndm_flax"] = ["FlaxPNDMScheduler"]
diff --git a/src/diffusers/schedulers/scheduling_euler_discrete_flax.py b/src/diffusers/schedulers/scheduling_euler_discrete_flax.py
new file mode 100644
index 000000000000..179a0ceb470f
--- /dev/null
+++ b/src/diffusers/schedulers/scheduling_euler_discrete_flax.py
@@ -0,0 +1,265 @@
+# Copyright 2023 Katherine Crowson and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
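+#
+# NOTE: JAX port of the PyTorch `EulerDiscreteScheduler`. JAX transformations
+# require pure functions, so all mutable quantities (timesteps, sigmas, number
+# of inference steps) live in the immutable `EulerDiscreteSchedulerState`
+# dataclass that is passed to, and returned from, every scheduler method.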
+
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import flax
+import jax.numpy as jnp
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from .scheduling_utils_flax import (
+    CommonSchedulerState,
+    FlaxKarrasDiffusionSchedulers,
+    FlaxSchedulerMixin,
+    FlaxSchedulerOutput,
+    broadcast_to_shape_from_left,
+)
+
+
+@flax.struct.dataclass
+class EulerDiscreteSchedulerState:
+    common: CommonSchedulerState
+
+    # setable values
+    init_noise_sigma: jnp.ndarray
+    timesteps: jnp.ndarray
+    sigmas: jnp.ndarray
+    num_inference_steps: Optional[int] = None
+
+    @classmethod
+    def create(
+        cls, common: CommonSchedulerState, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray, sigmas: jnp.ndarray
+    ):
+        return cls(common=common, init_noise_sigma=init_noise_sigma, timesteps=timesteps, sigmas=sigmas)
+
+
+@dataclass
+class FlaxEulerDiscreteSchedulerOutput(FlaxSchedulerOutput):
+    state: EulerDiscreteSchedulerState
+
+
+class FlaxEulerDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
+    """
+    Euler scheduler (Algorithm 2) from Karras et al. (2022) https://arxiv.org/abs/2206.00364. Based on the original
+    k-diffusion implementation by Katherine Crowson:
+    https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L51
+
+
+    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+    [`SchedulerMixin`] provides general loading and saving functionality via the [`~SchedulerMixin.save_pretrained`]
+    and [`~SchedulerMixin.from_pretrained`] functions.
+
+    Args:
+        num_train_timesteps (`int`): number of diffusion steps used to train the model.
+        beta_start (`float`): the starting `beta` value of inference.
+        beta_end (`float`): the final `beta` value.
+        beta_schedule (`str`):
+            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+            `linear` or `scaled_linear`.
+        trained_betas (`jnp.ndarray`, optional):
+            option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
+        prediction_type (`str`, default `epsilon`, optional):
+            prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
+            process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4 of
+            https://imagen.research.google/video/paper.pdf)
+        timestep_spacing (`str`, default `"linspace"`, optional):
+            how the inference timesteps are distributed over the training schedule, one of `linspace` or `leading`
+            (see `set_timesteps`).
+        dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
+            the `dtype` used for params and computation.
+    """
+
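+    # Declared compatible with the other Karras-style Flax schedulers, so a config
+    # saved by any of them can be loaded into this class.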
+ """ + + _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers] + + dtype: jnp.dtype + + @property + def has_state(self): + return True + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[jnp.ndarray] = None, + prediction_type: str = "epsilon", + timestep_spacing: str = "linspace", + dtype: jnp.dtype = jnp.float32, + ): + self.dtype = dtype + + def create_state(self, common: Optional[CommonSchedulerState] = None) -> EulerDiscreteSchedulerState: + if common is None: + common = CommonSchedulerState.create(self) + + timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1] + sigmas = ((1 - common.alphas_cumprod) / common.alphas_cumprod) ** 0.5 + sigmas = jnp.interp(timesteps, jnp.arange(0, len(sigmas)), sigmas) + sigmas = jnp.concatenate([sigmas, jnp.array([0.0], dtype=self.dtype)]) + + # standard deviation of the initial noise distribution + if self.config.timestep_spacing in ["linspace", "trailing"]: + init_noise_sigma = sigmas.max() + else: + init_noise_sigma = (sigmas.max() ** 2 + 1) ** 0.5 + + return EulerDiscreteSchedulerState.create( + common=common, + init_noise_sigma=init_noise_sigma, + timesteps=timesteps, + sigmas=sigmas, + ) + + def scale_model_input(self, state: EulerDiscreteSchedulerState, sample: jnp.ndarray, timestep: int) -> jnp.ndarray: + """ + Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm. + + Args: + state (`EulerDiscreteSchedulerState`): + the `FlaxEulerDiscreteScheduler` state data class instance. + sample (`jnp.ndarray`): + current instance of sample being created by diffusion process. + timestep (`int`): + current discrete timestep in the diffusion chain. + + Returns: + `jnp.ndarray`: scaled input sample + """ + (step_index,) = jnp.where(state.timesteps == timestep, size=1) + step_index = step_index[0] + + sigma = state.sigmas[step_index] + sample = sample / ((sigma**2 + 1) ** 0.5) + return sample + + def set_timesteps( + self, state: EulerDiscreteSchedulerState, num_inference_steps: int, shape: Tuple = () + ) -> EulerDiscreteSchedulerState: + """ + Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + state (`EulerDiscreteSchedulerState`): + the `FlaxEulerDiscreteScheduler` state data class instance. + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. 
+ """ + + if self.config.timestep_spacing == "linspace": + timesteps = jnp.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=self.dtype) + elif self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // num_inference_steps + timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float) + timesteps += 1 + else: + raise ValueError( + f"timestep_spacing must be one of ['linspace', 'leading'], got {self.config.timestep_spacing}" + ) + + sigmas = ((1 - state.common.alphas_cumprod) / state.common.alphas_cumprod) ** 0.5 + sigmas = jnp.interp(timesteps, jnp.arange(0, len(sigmas)), sigmas) + sigmas = jnp.concatenate([sigmas, jnp.array([0.0], dtype=self.dtype)]) + + # standard deviation of the initial noise distribution + if self.config.timestep_spacing in ["linspace", "trailing"]: + init_noise_sigma = sigmas.max() + else: + init_noise_sigma = (sigmas.max() ** 2 + 1) ** 0.5 + + return state.replace( + timesteps=timesteps, + sigmas=sigmas, + num_inference_steps=num_inference_steps, + init_noise_sigma=init_noise_sigma, + ) + + def step( + self, + state: EulerDiscreteSchedulerState, + model_output: jnp.ndarray, + timestep: int, + sample: jnp.ndarray, + return_dict: bool = True, + ) -> Union[FlaxEulerDiscreteSchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + state (`EulerDiscreteSchedulerState`): + the `FlaxEulerDiscreteScheduler` state data class instance. + model_output (`jnp.ndarray`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`jnp.ndarray`): + current instance of sample being created by diffusion process. + order: coefficient for multi-step inference. + return_dict (`bool`): option for returning tuple rather than FlaxEulerDiscreteScheduler class + + Returns: + [`FlaxEulerDiscreteScheduler`] or `tuple`: [`FlaxEulerDiscreteScheduler`] if `return_dict` is True, + otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. + + """ + if state.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + (step_index,) = jnp.where(state.timesteps == timestep, size=1) + step_index = step_index[0] + + sigma = state.sigmas[step_index] + + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + if self.config.prediction_type == "epsilon": + pred_original_sample = sample - sigma * model_output + elif self.config.prediction_type == "v_prediction": + # * c_out + input * c_skip + pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1)) + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" + ) + + # 2. 
+        derivative = (sample - pred_original_sample) / sigma
+
+        # dt = sigma_next - sigma
+        dt = state.sigmas[step_index + 1] - sigma
+
+        prev_sample = sample + derivative * dt
+
+        if not return_dict:
+            return (prev_sample, state)
+
+        return FlaxEulerDiscreteSchedulerOutput(prev_sample=prev_sample, state=state)
+
+    def add_noise(
+        self,
+        state: EulerDiscreteSchedulerState,
+        original_samples: jnp.ndarray,
+        noise: jnp.ndarray,
+        timesteps: jnp.ndarray,
+    ) -> jnp.ndarray:
+        sigma = state.sigmas[timesteps].flatten()
+        sigma = broadcast_to_shape_from_left(sigma, noise.shape)
+
+        noisy_samples = original_samples + noise * sigma
+
+        return noisy_samples
+
+    def __len__(self):
+        return self.config.num_train_timesteps
diff --git a/src/diffusers/schedulers/scheduling_utils_flax.py b/src/diffusers/schedulers/scheduling_utils_flax.py
index e2af382c8215..53d92ed33b9f 100644
--- a/src/diffusers/schedulers/scheduling_utils_flax.py
+++ b/src/diffusers/schedulers/scheduling_utils_flax.py
@@ -37,6 +37,7 @@ class FlaxKarrasDiffusionSchedulers(Enum):
     FlaxPNDMScheduler = 3
     FlaxLMSDiscreteScheduler = 4
     FlaxDPMSolverMultistepScheduler = 5
+    FlaxEulerDiscreteScheduler = 6
 
 
 @dataclass
diff --git a/src/diffusers/utils/dummy_flax_and_transformers_objects.py b/src/diffusers/utils/dummy_flax_and_transformers_objects.py
index 162bac1c4331..5e65e5349bb0 100644
--- a/src/diffusers/utils/dummy_flax_and_transformers_objects.py
+++ b/src/diffusers/utils/dummy_flax_and_transformers_objects.py
@@ -60,3 +60,18 @@ def from_config(cls, *args, **kwargs):
     @classmethod
     def from_pretrained(cls, *args, **kwargs):
         requires_backends(cls, ["flax", "transformers"])
+
+
+class FlaxStableDiffusionXLPipeline(metaclass=DummyObject):
+    _backends = ["flax", "transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["flax", "transformers"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["flax", "transformers"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["flax", "transformers"])
diff --git a/src/diffusers/utils/dummy_flax_objects.py b/src/diffusers/utils/dummy_flax_objects.py
index 2bb80d136f33..5fa8dbc81931 100644
--- a/src/diffusers/utils/dummy_flax_objects.py
+++ b/src/diffusers/utils/dummy_flax_objects.py
@@ -122,6 +122,21 @@ def from_pretrained(cls, *args, **kwargs):
     requires_backends(cls, ["flax"])
 
 
+class FlaxEulerDiscreteScheduler(metaclass=DummyObject):
+    _backends = ["flax"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["flax"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["flax"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["flax"])
+
+
 class FlaxKarrasVeScheduler(metaclass=DummyObject):
     _backends = ["flax"]
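
For reference, a minimal usage sketch of the pipeline added by this patch, following the conventions visible in the diff (`prepare_inputs`, a `params` dict, `jit=True` for the pmapped path). The checkpoint name is an assumption standing in for any SDXL repository that ships Flax weights, and defaults such as `num_inference_steps` may differ:

    import jax
    import jax.numpy as jnp
    from flax.jax_utils import replicate
    from flax.training.common_utils import shard

    from diffusers import FlaxStableDiffusionXLPipeline

    # Assumed checkpoint; any SDXL repo with Flax weights would work the same way.
    pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", dtype=jnp.bfloat16
    )

    num_devices = jax.device_count()
    prompts = ["a photograph of an astronaut riding a horse"] * num_devices

    # Tokenizes with both SDXL tokenizers -> shape (batch, 2, seq_len)
    prompt_ids = pipeline.prepare_inputs(prompts)

    # Replicate the params and shard the inputs; `jit=True` selects the pmapped path.
    params = replicate(params)
    prompt_ids = shard(prompt_ids)
    rng = jax.random.split(jax.random.PRNGKey(0), num_devices)

    output = pipeline(prompt_ids, params, rng, num_inference_steps=25, jit=True)
    images = output.images  # (num_devices, batch_per_device, height, width, 3)

Since `FlaxEulerDiscreteScheduler` is registered as compatible with the other Karras-style schedulers, it should also be possible to swap it in the usual Flax way: `FlaxEulerDiscreteScheduler.from_pretrained(..., subfolder="scheduler")` returns a `(scheduler, state)` pair that replaces `pipeline.scheduler` and `params["scheduler"]`.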