From cbdd6d6c03c45ab1b2cb31348b7f952c6bf827af Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Martin=20M=C3=BCller?=
Date: Tue, 11 Jul 2023 21:47:03 +0000
Subject: [PATCH 01/13] support transformer_layers_per_block in flax UNet

---
 src/diffusers/models/unet_2d_blocks_flax.py    |  9 ++++++---
 src/diffusers/models/unet_2d_condition_flax.py | 15 +++++++++++++--
 2 files changed, 19 insertions(+), 5 deletions(-)

diff --git a/src/diffusers/models/unet_2d_blocks_flax.py b/src/diffusers/models/unet_2d_blocks_flax.py
index 0d1447570dda..8c24b9f264b0 100644
--- a/src/diffusers/models/unet_2d_blocks_flax.py
+++ b/src/diffusers/models/unet_2d_blocks_flax.py
@@ -52,6 +52,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
     only_cross_attention: bool = False
     use_memory_efficient_attention: bool = False
     dtype: jnp.dtype = jnp.float32
+    transformer_layers_per_block: int = 1
 
     def setup(self):
         resnets = []
@@ -72,7 +73,7 @@ def setup(self):
                 in_channels=self.out_channels,
                 n_heads=self.num_attention_heads,
                 d_head=self.out_channels // self.num_attention_heads,
-                depth=1,
+                depth=self.transformer_layers_per_block,
                 use_linear_projection=self.use_linear_projection,
                 only_cross_attention=self.only_cross_attention,
                 use_memory_efficient_attention=self.use_memory_efficient_attention,
@@ -192,6 +193,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
     only_cross_attention: bool = False
     use_memory_efficient_attention: bool = False
     dtype: jnp.dtype = jnp.float32
+    transformer_layers_per_block: int = 1
 
     def setup(self):
         resnets = []
@@ -213,7 +215,7 @@ def setup(self):
                 in_channels=self.out_channels,
                 n_heads=self.num_attention_heads,
                 d_head=self.out_channels // self.num_attention_heads,
-                depth=1,
+                depth=self.transformer_layers_per_block,
                 use_linear_projection=self.use_linear_projection,
                 only_cross_attention=self.only_cross_attention,
                 use_memory_efficient_attention=self.use_memory_efficient_attention,
@@ -331,6 +333,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
     use_linear_projection: bool = False
     use_memory_efficient_attention: bool = False
     dtype: jnp.dtype = jnp.float32
+    transformer_layers_per_block: int = 1
 
     def setup(self):
         # there is always at least one resnet
@@ -350,7 +353,7 @@ def setup(self):
                 in_channels=self.in_channels,
                 n_heads=self.num_attention_heads,
                 d_head=self.in_channels // self.num_attention_heads,
-                depth=1,
+                depth=self.transformer_layers_per_block,
                 use_linear_projection=self.use_linear_projection,
                 use_memory_efficient_attention=self.use_memory_efficient_attention,
                 dtype=self.dtype,
diff --git a/src/diffusers/models/unet_2d_condition_flax.py b/src/diffusers/models/unet_2d_condition_flax.py
index de39bc75d2e3..d3ce0f8cdce6 100644
--- a/src/diffusers/models/unet_2d_condition_flax.py
+++ b/src/diffusers/models/unet_2d_condition_flax.py
@@ -116,6 +116,8 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
     flip_sin_to_cos: bool = True
     freq_shift: int = 0
     use_memory_efficient_attention: bool = False
+    transformer_layers_per_block: Union[int, Tuple[int]] = 1
+
 
     def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
         # init input tensors
@@ -168,6 +170,11 @@ def setup(self):
         if isinstance(num_attention_heads, int):
             num_attention_heads = (num_attention_heads,) * len(self.down_block_types)
 
+        # transformer layers per block
+        transformer_layers_per_block = self.transformer_layers_per_block
+        if isinstance(transformer_layers_per_block, int):
+            transformer_layers_per_block = [transformer_layers_per_block] * len(self.down_block_types)
+
         # down
         down_blocks = []
         output_channel = block_out_channels[0]
@@
-182,12 +189,13 @@ def setup(self): out_channels=output_channel, dropout=self.dropout, num_layers=self.layers_per_block, + transformer_layers_per_block=transformer_layers_per_block[i], num_attention_heads=num_attention_heads[i], add_downsample=not is_final_block, use_linear_projection=self.use_linear_projection, only_cross_attention=only_cross_attention[i], use_memory_efficient_attention=self.use_memory_efficient_attention, - dtype=self.dtype, + dtype=self.dtype ) else: down_block = FlaxDownBlock2D( @@ -207,9 +215,10 @@ def setup(self): in_channels=block_out_channels[-1], dropout=self.dropout, num_attention_heads=num_attention_heads[-1], + transformer_layers_per_block=transformer_layers_per_block[-1], use_linear_projection=self.use_linear_projection, use_memory_efficient_attention=self.use_memory_efficient_attention, - dtype=self.dtype, + dtype=self.dtype ) # up @@ -218,6 +227,7 @@ def setup(self): reversed_num_attention_heads = list(reversed(num_attention_heads)) only_cross_attention = list(reversed(only_cross_attention)) output_channel = reversed_block_out_channels[0] + reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) for i, up_block_type in enumerate(self.up_block_types): prev_output_channel = output_channel output_channel = reversed_block_out_channels[i] @@ -231,6 +241,7 @@ def setup(self): out_channels=output_channel, prev_output_channel=prev_output_channel, num_layers=self.layers_per_block + 1, + transformer_layers_per_block=reversed_transformer_layers_per_block[i], num_attention_heads=reversed_num_attention_heads[i], add_upsample=not is_final_block, dropout=self.dropout, From 4c78659d744fc6edcf354800b21afc23c66196d5 Mon Sep 17 00:00:00 2001 From: Martin Muller Date: Wed, 12 Jul 2023 15:36:26 +0200 Subject: [PATCH 02/13] add support for text_time additional embeddings to Flax UNet --- src/diffusers/models/unet_2d_condition.py | 3 +- .../models/unet_2d_condition_flax.py | 50 +++++++++++++++++-- 2 files changed, 48 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index dee71bead0f9..d21365179728 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -205,7 +205,7 @@ def __init__( class_embeddings_concat: bool = False, mid_block_only_cross_attention: Optional[bool] = None, cross_attention_norm: Optional[str] = None, - addition_embed_type_num_heads=64, + addition_embed_type_num_heads = 64, ): super().__init__() @@ -848,7 +848,6 @@ def forward( time_ids = added_cond_kwargs.get("time_ids") time_embeds = self.add_time_proj(time_ids.flatten()) time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) - add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) add_embeds = add_embeds.to(emb.dtype) aug_emb = self.add_embedding(add_embeds) diff --git a/src/diffusers/models/unet_2d_condition_flax.py b/src/diffusers/models/unet_2d_condition_flax.py index d3ce0f8cdce6..2e718bf97069 100644 --- a/src/diffusers/models/unet_2d_condition_flax.py +++ b/src/diffusers/models/unet_2d_condition_flax.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Optional, Tuple, Union +from typing import Optional, Tuple, Union, Dict import flax import flax.linen as nn @@ -117,7 +117,10 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): freq_shift: int = 0 use_memory_efficient_attention: bool = False transformer_layers_per_block: Union[int, Tuple[int]] = 1 - + addition_embed_type: Optional[str] = None + addition_time_embed_dim: Optional[int] = None + addition_embed_type_num_heads: int = 64 + projection_class_embeddings_input_dim: Optional[int] = None def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict: # init input tensors @@ -129,7 +132,13 @@ def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict: params_rng, dropout_rng = jax.random.split(rng) rngs = {"params": params_rng, "dropout": dropout_rng} - return self.init(rngs, sample, timesteps, encoder_hidden_states)["params"] + added_cond_kwargs = None + if self.addition_embed_type == 'text_time': + added_cond_kwargs = { + 'text_embeds': jnp.zeros((1, 1280), dtype=jnp.float32), # TODO: Check where this is comming from - block_out_channels[-1]? + 'time_ids': jnp.zeros((1, 6), dtype=jnp.float32) + } + return self.init(rngs, sample, timesteps, encoder_hidden_states, added_cond_kwargs)["params"] def setup(self): block_out_channels = self.block_out_channels @@ -175,6 +184,17 @@ def setup(self): if isinstance(transformer_layers_per_block, int): transformer_layers_per_block = [transformer_layers_per_block] * len(self.down_block_types) + # addition embed types + if self.addition_embed_type is None: + self.add_embedding = None + elif self.addition_embed_type == 'text_time': + if self.addition_time_embed_dim is None: + raise ValueError(f'addition_embed_type {self.addition_embed_type} requires `addition_time_embed_dim` to not be None') + self.add_time_proj = FlaxTimesteps(self.addition_time_embed_dim, self.flip_sin_to_cos, self.freq_shift) + self.add_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype) + else: + raise ValueError(f"addition_embed_type: {self.addition_embed_type} must be None or `text_time`.") + # down down_blocks = [] output_channel = block_out_channels[0] @@ -280,6 +300,7 @@ def __call__( sample, timesteps, encoder_hidden_states, + added_cond_kwargs: Optional[Union[Dict, FrozenDict]] = None, down_block_additional_residuals=None, mid_block_additional_residual=None, return_dict: bool = True, @@ -311,6 +332,29 @@ def __call__( t_emb = self.time_proj(timesteps) t_emb = self.time_embedding(t_emb) + # additional embeddings + aug_emb = None + if self.addition_embed_type == 'text_time': + if added_cond_kwargs is None: + raise ValueError(f'Need to provide argument `added_cond_kwargs` for {self.__class__} when using `addition_embed_type={self.addition_embed_type}`') + text_embeds = added_cond_kwargs.get("text_embeds") + if text_embeds is None: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + if time_ids is None: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + # compute time embeds + time_embeds = self.add_time_proj(jnp.ravel(time_ids)) # (1, 6) => (6,) => (6, 256) + time_embeds = jnp.reshape(time_embeds, (text_embeds.shape[0], -1)) + add_embeds = jnp.concatenate([text_embeds, time_embeds], 
axis=-1) + aug_emb = self.add_embedding(add_embeds) + + t_emb = t_emb + aug_emb if aug_emb is not None else t_emb + # 2. pre-process sample = jnp.transpose(sample, (0, 2, 3, 1)) sample = self.conv_in(sample) From 3aa31641b8e80df5c8bcc82cc677a547d91cdc9e Mon Sep 17 00:00:00 2001 From: Martin Muller Date: Wed, 12 Jul 2023 22:00:01 +0200 Subject: [PATCH 03/13] rename attention layers for VAE --- src/diffusers/models/modeling_flax_pytorch_utils.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/modeling_flax_pytorch_utils.py b/src/diffusers/models/modeling_flax_pytorch_utils.py index f9de83f87dab..3e2ba05d4d88 100644 --- a/src/diffusers/models/modeling_flax_pytorch_utils.py +++ b/src/diffusers/models/modeling_flax_pytorch_utils.py @@ -42,9 +42,18 @@ def rename_key(key): # and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict): """Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary""" - # conv norm or layer norm renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",) + + # rename attention layers + if len(pt_tuple_key) > 1: + for rename_from, rename_to in (('to_out_0', 'proj_attn'), ('to_k', 'key'), ('to_v', 'value'), ('to_q', 'query')): + if pt_tuple_key[-2] == rename_from: + weight_name = 'kernel' if pt_tuple_key[-1] == 'weight' else 'bias' + renamed_pt_tuple_key = pt_tuple_key[:-2] + (rename_to, weight_name) + if renamed_pt_tuple_key in random_flax_state_dict: + return renamed_pt_tuple_key, pt_tensor + if ( any("norm" in str_ for str_ in pt_tuple_key) and (pt_tuple_key[-1] == "bias") From bc267e4b2b04f5ad7223b54109fe88fde25d6232 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20M=C3=BCller?= Date: Thu, 13 Jul 2023 20:34:36 +0000 Subject: [PATCH 04/13] add shape asserts when renaming attention layers --- src/diffusers/models/modeling_flax_pytorch_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/modeling_flax_pytorch_utils.py b/src/diffusers/models/modeling_flax_pytorch_utils.py index 3e2ba05d4d88..8bd631f797bd 100644 --- a/src/diffusers/models/modeling_flax_pytorch_utils.py +++ b/src/diffusers/models/modeling_flax_pytorch_utils.py @@ -49,9 +49,11 @@ def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dic if len(pt_tuple_key) > 1: for rename_from, rename_to in (('to_out_0', 'proj_attn'), ('to_k', 'key'), ('to_v', 'value'), ('to_q', 'query')): if pt_tuple_key[-2] == rename_from: - weight_name = 'kernel' if pt_tuple_key[-1] == 'weight' else 'bias' + weight_name = pt_tuple_key[-1] + weight_name = 'kernel' if weight_name == 'weight' else weight_name renamed_pt_tuple_key = pt_tuple_key[:-2] + (rename_to, weight_name) if renamed_pt_tuple_key in random_flax_state_dict: + assert random_flax_state_dict[renamed_pt_tuple_key].shape == pt_tensor.shape return renamed_pt_tuple_key, pt_tensor if ( From ea0e675e1007a737a93af7fa80daffca6e3dadd2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20M=C3=BCller?= Date: Fri, 14 Jul 2023 21:13:51 +0000 Subject: [PATCH 05/13] transpose VAE attention layers --- src/diffusers/models/modeling_flax_pytorch_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/modeling_flax_pytorch_utils.py b/src/diffusers/models/modeling_flax_pytorch_utils.py index 8bd631f797bd..e92556da4e85 100644 --- 
a/src/diffusers/models/modeling_flax_pytorch_utils.py +++ b/src/diffusers/models/modeling_flax_pytorch_utils.py @@ -53,8 +53,8 @@ def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dic weight_name = 'kernel' if weight_name == 'weight' else weight_name renamed_pt_tuple_key = pt_tuple_key[:-2] + (rename_to, weight_name) if renamed_pt_tuple_key in random_flax_state_dict: - assert random_flax_state_dict[renamed_pt_tuple_key].shape == pt_tensor.shape - return renamed_pt_tuple_key, pt_tensor + assert random_flax_state_dict[renamed_pt_tuple_key].shape == pt_tensor.T.shape + return renamed_pt_tuple_key, pt_tensor.T if ( any("norm" in str_ for str_ in pt_tuple_key) From 1cc6c37bb06917600cab0cb4a6041eba5e68411b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20M=C3=BCller?= Date: Mon, 17 Jul 2023 15:33:54 +0000 Subject: [PATCH 06/13] add pipeline flax SDXL code [WIP] --- src/diffusers/__init__.py | 1 + src/diffusers/pipelines/__init__.py | 3 + .../pipelines/stable_diffusion_xl/__init__.py | 13 +- .../pipeline_flax_stable_diffusion_xl.py | 158 ++++++++++++++++++ 4 files changed, 174 insertions(+), 1 deletion(-) create mode 100644 src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index a7fc9a36f271..045a2d989495 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -270,6 +270,7 @@ FlaxStableDiffusionImg2ImgPipeline, FlaxStableDiffusionInpaintPipeline, FlaxStableDiffusionPipeline, + FlaxStableDiffusionXLPipeline ) try: diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index c3968406ed90..0d5c93bd0aab 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -173,6 +173,9 @@ FlaxStableDiffusionInpaintPipeline, FlaxStableDiffusionPipeline, ) + from .stable_diffusion_xl import ( + FlaxStableDiffusionXLPipeline, + ) try: if not (is_transformers_available() and is_torch_available() and is_note_seq_available()): raise OptionalDependencyNotAvailable() diff --git a/src/diffusers/pipelines/stable_diffusion_xl/__init__.py b/src/diffusers/pipelines/stable_diffusion_xl/__init__.py index d61ba9fab3a3..4d9e8fe5b5f0 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/__init__.py @@ -4,7 +4,13 @@ import numpy as np import PIL -from ...utils import BaseOutput, is_invisible_watermark_available, is_torch_available, is_transformers_available +from ...utils import ( + BaseOutput, + is_invisible_watermark_available, + is_torch_available, + is_transformers_available, + is_flax_available, +) @dataclass @@ -24,3 +30,8 @@ class StableDiffusionXLPipelineOutput(BaseOutput): if is_transformers_available() and is_torch_available() and is_invisible_watermark_available(): from .pipeline_stable_diffusion_xl import StableDiffusionXLPipeline from .pipeline_stable_diffusion_xl_img2img import StableDiffusionXLImg2ImgPipeline + + +if is_flax_available(): + from .pipeline_flax_stable_diffusion_xl import FlaxStableDiffusionXLPipeline + diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py new file mode 100644 index 000000000000..cd14b1870e79 --- /dev/null +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py @@ -0,0 +1,158 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Union, Dict, Optional +from flax.core.frozen_dict import FrozenDict +from transformers import CLIPTokenizer, FlaxCLIPTextModel +import jax.numpy as jnp +import jax + +from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel +from ...schedulers import ( + FlaxDDIMScheduler, + FlaxDPMSolverMultistepScheduler, + FlaxLMSDiscreteScheduler, + FlaxPNDMScheduler, +) +from ..pipeline_flax_utils import FlaxDiffusionPipeline +from diffusers.utils import logging + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +# Set to True to use python for loop instead of jax.fori_loop for easier debugging +DEBUG = False + + +class FlaxStableDiffusionXLPipeline(FlaxDiffusionPipeline): + def __init__( + self, + text_encoder: FlaxCLIPTextModel, + text_encoder_2: FlaxCLIPTextModel, + vae: FlaxAutoencoderKL, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: FlaxUNet2DConditionModel, + scheduler: Union[ + FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler + ], + dtype: jnp.dtype = jnp.float32, + ): + super().__init__() + self.dtype = dtype + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + scheduler=scheduler + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + def prepare_inputs(self, prompt: Union[str, List[str]]): + if not isinstance(prompt, (str, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if self.tokenizer is not None: + assert self.tokenizer_2 is not None + tokenizers = [self.tokenizer, self.tokenizer_2] + else: + tokenizers = [self.tokenizer_2] + inputs = [] + for tokenizer in enumerate(tokenizers): + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="np" + ) + inputs.append(text_inputs.input_ids) + inputs = jnp.stack(inputs) + return inputs + + def __call__( + self, + prompt_ids: jax.Array, + params: Union[Dict, FrozenDict], + prng_seed: jax.random.KeyArray, + num_inference_steps: int = 50, + guidance_scale: Union[float, jax.Array] = 7.5, + height: Optional[int] = None, + width: Optional[int] = None, + ): + + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + if isinstance(guidance_scale, float): + # Convert to a tensor so each device gets a copy. Follow the prompt_ids for + # shape information, as they may be sharded (when `jit` is `True`), or not. 
+            guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0])
+            if len(prompt_ids.shape) > 2:
+                # Assume sharded
+                guidance_scale = guidance_scale[:, None]
+
+        # TODO: support jit
+        images = self._generate(
+            prompt_ids,
+            params,
+            prng_seed,
+            num_inference_steps,
+            height,
+            width,
+            guidance_scale
+        )
+
+    def get_embeddings(self, prompt_ids: jax.Array, params: Union[Dict, FrozenDict]):
+        if prompt_ids.shape[0] == 2:
+            # using both CLIP models
+            prompt_embeds = self.text_encoder(prompt_ids[0], params=params['text_encoder'], output_hidden_states=True)
+            prompt_embeds = prompt_embeds['hidden_states'][-2]
+            prompt_embeds_2 = self.text_encoder_2(prompt_ids[1], params=params['text_encoder_2'], output_hidden_states=True)
+            prompt_embeds_2 = prompt_embeds_2['hidden_states'][-2]
+        else:
+            prompt_embeds = jnp.array([])  # dummy embedding for first CLIP model
+            prompt_embeds_2 = self.text_encoder_2(prompt_ids[1], params=params['text_encoder_2'], output_hidden_states=True)
+            prompt_embeds_2 = prompt_embeds_2['hidden_states'][-2]
+        prompt_embeds = jnp.concatenate([prompt_embeds, prompt_embeds_2], axis=-1)
+        return prompt_embeds
+
+    def _generate(
+        self,
+        prompt_ids: jax.Array,
+        params: Union[Dict, FrozenDict],
+        prng_seed: jax.random.KeyArray,
+        num_inference_steps: int,
+        height: int,
+        width: int,
+        guidance_scale: float,
+        neg_prompt_ids: Optional[jax.Array] = None,
+    ):
+        if height % 8 != 0 or width % 8 != 0:
+            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+        # 1. Encode input prompt
+        prompt_embeds = self.get_embeddings(prompt_ids, params)
+
+        # 2. Get unconditional embeddings
+        batch_size = prompt_embeds.shape[0]
+        if neg_prompt_ids is None:
+            neg_prompt_ids = self.prepare_inputs([""] * batch_size)
+        neg_prompt_embeds = self.get_embeddings(neg_prompt_ids, params)
+
+        context = jnp.concatenate([neg_prompt_embeds, prompt_embeds], axis=0)  # (2, 77, 2048)

From 9771feb380d1a58636ca511cebc536d0ad57abf7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Martin=20M=C3=BCller?=
Date: Tue, 18 Jul 2023 00:08:29 +0000
Subject: [PATCH 07/13] continue adding pipeline flax SDXL code [WIP]

---
 .../pipeline_flax_stable_diffusion_xl.py      | 75 ++++++++++++++++---
 1 file changed, 65 insertions(+), 10 deletions(-)

diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py
index cd14b1870e79..42935848c15d 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py
@@ -98,6 +98,7 @@ def __call__(
         # 0. Default height and width to unet
         height = height or self.unet.config.sample_size * self.vae_scale_factor
         width = width or self.unet.config.sample_size * self.vae_scale_factor
+        do_classifier_free_guidance = guidance_scale > 1.0
 
         if isinstance(guidance_scale, float):
             # Convert to a tensor so each device gets a copy. Follow the prompt_ids for
@@ -115,7 +116,8 @@ def __call__(
             num_inference_steps,
             height,
             width,
-            guidance_scale
+            guidance_scale,
+            do_classifier_free_guidance
         )
 
     def get_embeddings(self, prompt_ids: jax.Array, params: Union[Dict, FrozenDict]):
@@ -123,14 +125,20 @@ def get_embeddings(self, prompt_ids: jax.Array, params: Union[Dict, FrozenDict])
             # using both CLIP models
             prompt_embeds = self.text_encoder(prompt_ids[0], params=params['text_encoder'], output_hidden_states=True)
             prompt_embeds = prompt_embeds['hidden_states'][-2]
-            prompt_embeds_2 = self.text_encoder_2(prompt_ids[1], params=params['text_encoder_2'], output_hidden_states=True)
-            prompt_embeds_2 = prompt_embeds_2['hidden_states'][-2]
+            prompt_embeds_2_out = self.text_encoder_2(prompt_ids[1], params=params['text_encoder_2'], output_hidden_states=True)
+            prompt_embeds_2 = prompt_embeds_2_out['hidden_states'][-2]
         else:
             prompt_embeds = jnp.array([])  # dummy embedding for first CLIP model
-            prompt_embeds_2 = self.text_encoder_2(prompt_ids[1], params=params['text_encoder_2'], output_hidden_states=True)
-            prompt_embeds_2 = prompt_embeds_2['hidden_states'][-2]
+            prompt_embeds_2_out = self.text_encoder_2(prompt_ids[1], params=params['text_encoder_2'], output_hidden_states=True)
+            prompt_embeds_2 = prompt_embeds_2_out['hidden_states'][-2]
+        pooled_embeds = prompt_embeds_2_out['pooler_output']  # use second text encoder's pooled output
         prompt_embeds = jnp.concatenate([prompt_embeds, prompt_embeds_2], axis=-1)
-        return prompt_embeds
+        return prompt_embeds, pooled_embeds
+
+    def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):
+        add_time_ids = list(original_size + crops_coords_top_left + target_size)  # TODO: This is currently not jit'able - probably need to pass add_time_ids as input to __call__
+        add_time_ids = jnp.array([add_time_ids], dtype=dtype)
+        return add_time_ids
 
     def _generate(
         self,
@@ -140,19 +148,66 @@ def _generate(
         num_inference_steps: int,
         height: int,
         width: int,
-        guidance_scale: float,
+        guidance_scale: float = 7.5,
+        do_classifier_free_guidance: bool = True,
+        latents: Optional[jax.Array] = None,
         neg_prompt_ids: Optional[jax.Array] = None,
+        original_size: tuple = (1024, 1024),
+        crops_coords_top_left: tuple = (0, 0),
+        target_size: tuple = (1024, 1024),
     ):
         if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
 
         # 1. Encode input prompt
-        prompt_embeds = self.get_embeddings(prompt_ids, params)
+        prompt_embeds, pooled_embeds = self.get_embeddings(prompt_ids, params)
 
         # 2. Get unconditional embeddings
         batch_size = prompt_embeds.shape[0]
         if neg_prompt_ids is None:
             neg_prompt_ids = self.prepare_inputs([""] * batch_size)
-        neg_prompt_embeds = self.get_embeddings(neg_prompt_ids, params)
-
+        # TODO: properly support without classifier guidance here (or drop support entirely)
+        neg_prompt_embeds, pooled_neg_embeds = self.get_embeddings(neg_prompt_ids, params)
         context = jnp.concatenate([neg_prompt_embeds, prompt_embeds], axis=0)  # (2, 77, 2048)
+
+        # 3.
Create random latents + latents_shape = ( + batch_size, + self.unet.config.in_channels, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if latents is None: + latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32) + else: + if latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * params["scheduler"].init_noise_sigma + + # 4. Prepare scheduler state + scheduler_state = self.scheduler.set_timesteps( + params["scheduler"], num_inference_steps=num_inference_steps, shape=latents.shape + ) + + # 5. Prepare added embeddings + add_text_embeds = prompt_embeds + add_time_ids = self._get_add_time_ids( + original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype + ) + + # 6. Denoising loop + def loop_body(step, args): + # TODO + pass + __import__('pdb').set_trace() + + if DEBUG: + # run with python for loop + for i in range(num_inference_steps): + latents, scheduler_state = loop_body(i, (latents, scheduler_state)) + else: + latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state)) + + # 7. Deocde latents + # TODO From 1137263d0fe09c691bc2ba20e01e1935d8f8e22a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20M=C3=BCller?= Date: Tue, 18 Jul 2023 00:46:33 +0000 Subject: [PATCH 08/13] cleanup --- src/diffusers/models/unet_2d_condition_flax.py | 6 +++--- .../pipeline_flax_stable_diffusion_xl.py | 1 - 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/unet_2d_condition_flax.py b/src/diffusers/models/unet_2d_condition_flax.py index 2e718bf97069..acfc84d5bfff 100644 --- a/src/diffusers/models/unet_2d_condition_flax.py +++ b/src/diffusers/models/unet_2d_condition_flax.py @@ -135,7 +135,7 @@ def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict: added_cond_kwargs = None if self.addition_embed_type == 'text_time': added_cond_kwargs = { - 'text_embeds': jnp.zeros((1, 1280), dtype=jnp.float32), # TODO: Check where this is comming from - block_out_channels[-1]? 
+ 'text_embeds': jnp.zeros((1, 1280), dtype=jnp.float32), # TODO: This should be set based on config 'time_ids': jnp.zeros((1, 6), dtype=jnp.float32) } return self.init(rngs, sample, timesteps, encoder_hidden_states, added_cond_kwargs)["params"] @@ -215,7 +215,7 @@ def setup(self): use_linear_projection=self.use_linear_projection, only_cross_attention=only_cross_attention[i], use_memory_efficient_attention=self.use_memory_efficient_attention, - dtype=self.dtype + dtype=self.dtype, ) else: down_block = FlaxDownBlock2D( @@ -238,7 +238,7 @@ def setup(self): transformer_layers_per_block=transformer_layers_per_block[-1], use_linear_projection=self.use_linear_projection, use_memory_efficient_attention=self.use_memory_efficient_attention, - dtype=self.dtype + dtype=self.dtype, ) # up diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py index 42935848c15d..e3de38d81c2d 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py @@ -200,7 +200,6 @@ def _generate( def loop_body(step, args): # TODO pass - __import__('pdb').set_trace() if DEBUG: # run with python for loop From f02d7958264666040ba5f34890a3fd2352da42a0 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Tue, 25 Jul 2023 07:28:14 +0000 Subject: [PATCH 09/13] Working on JIT support Fixed prompt embedding shapes so they work in parallel mode. Assuming we always have both text encoders for now, for simplicity. --- .../pipelines/stable_diffusion_xl/__init__.py | 17 +- .../pipeline_flax_stable_diffusion_xl.py | 161 +++++++++++++----- 2 files changed, 136 insertions(+), 42 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/__init__.py b/src/diffusers/pipelines/stable_diffusion_xl/__init__.py index 4d9e8fe5b5f0..c25f83b9a792 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/__init__.py @@ -32,6 +32,19 @@ class StableDiffusionXLPipelineOutput(BaseOutput): from .pipeline_stable_diffusion_xl_img2img import StableDiffusionXLImg2ImgPipeline -if is_flax_available(): - from .pipeline_flax_stable_diffusion_xl import FlaxStableDiffusionXLPipeline +if is_transformers_available() and is_flax_available(): + import flax + + @flax.struct.dataclass + class FlaxStableDiffusionXLPipelineOutput(BaseOutput): + """ + Output class for Flax Stable Diffusion XL pipelines. + Args: + images (`np.ndarray`) + Array of shape `(batch_size, height, width, num_channels)` with images from the diffusion pipeline. + """ + images: np.ndarray + + from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState + from .pipeline_flax_stable_diffusion_xl import FlaxStableDiffusionXLPipeline diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py index e3de38d81c2d..0b7e8ddb514a 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from functools import partial from typing import List, Union, Dict, Optional from flax.core.frozen_dict import FrozenDict from transformers import CLIPTokenizer, FlaxCLIPTextModel @@ -26,6 +27,8 @@ FlaxPNDMScheduler, ) from ..pipeline_flax_utils import FlaxDiffusionPipeline +from . import FlaxStableDiffusionXLPipelineOutput + from diffusers.utils import logging logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -66,14 +69,10 @@ def prepare_inputs(self, prompt: Union[str, List[str]]): if not isinstance(prompt, (str, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - if self.tokenizer is not None: - assert self.tokenizer_2 is not None - tokenizers = [self.tokenizer, self.tokenizer_2] - else: - tokenizers = [self.tokenizer_2] + # Assume we have the two encoders inputs = [] - for tokenizer in enumerate(tokenizers): - text_inputs = self.tokenizer( + for tokenizer in [self.tokenizer, self.tokenizer_2]: + text_inputs = tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, @@ -81,7 +80,7 @@ def prepare_inputs(self, prompt: Union[str, List[str]]): return_tensors="np" ) inputs.append(text_inputs.input_ids) - inputs = jnp.stack(inputs) + inputs = jnp.stack(inputs, axis=1) return inputs def __call__( @@ -93,12 +92,15 @@ def __call__( guidance_scale: Union[float, jax.Array] = 7.5, height: Optional[int] = None, width: Optional[int] = None, + latents: jnp.array = None, + neg_prompt_ids: jnp.array = None, + return_dict: bool = True, + jit: bool = False, ): # 0. Default height and width to unet height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor - do_classifier_free_guidance = guidance_scale > 1.0 if isinstance(guidance_scale, float): # Convert to a tensor so each device gets a copy. 
Follow the prompt_ids for @@ -108,30 +110,47 @@ def __call__( # Assume sharded guidance_scale = guidance_scale[:, None] - # TODO: support jit - images = self._generate( - prompt_ids, - params, - prng_seed, - num_inference_steps, - height, - width, - guidance_scale, - do_classifier_free_guidance - ) - - def get_embeddings(self, prompt_ids: jax.Array, params: Union[Dict, FrozenDict]): - if prompt_ids.shape[0] == 2: - # using both CLIP models - prompt_embeds = self.text_encoder(prompt_ids[0], params=params['text_encoder'], output_hidden_states=True) - prompt_embeds = prompt_embeds['hidden_states'][-2] - prompt_embeds_2_out = self.text_encoder_2(prompt_ids[1], params=params['text_encoder_2'], output_hidden_states=True) - prompt_embeds_2 = prompt_embeds_2_out['hidden_states'][-2] + if jit: + images = _p_generate( + self, + prompt_ids, + params, + prng_seed, + num_inference_steps, + height, + width, + guidance_scale, + latents, + neg_prompt_ids, + ) else: - prompt_embeds = jnp.array([]) # dummy embedding for first CLIP model - prompt_embeds_2_out = self.text_encoder_2(prompt_ids[1], params=params['text_encoder_2'], output_hidden_states=True) - prompt_embeds_2 = prompt_embeds_2_out['hidden_states'][-2] - pooled_embeds = prompt_embeds_2_out['pooler_output'] # use second text encoder's pooled output + images = self._generate( + prompt_ids, + params, + prng_seed, + num_inference_steps, + height, + width, + guidance_scale, + latents, + neg_prompt_ids, + ) + + if not return_dict: + return (images,) + + return FlaxStableDiffusionXLPipelineOutput(images=images) + + def get_embeddings(self, prompt_ids: jnp.array, params: Union[Dict, FrozenDict]): + # We assume we have the two encoders + # [2, 77] -> [2, 1, 77] + prompt_ids = jnp.expand_dims(prompt_ids, axis=-2) + + prompt_embeds = self.text_encoder(prompt_ids[0], params=params['text_encoder'], output_hidden_states=True) + prompt_embeds = prompt_embeds['hidden_states'][-2] + prompt_embeds_2_out = self.text_encoder_2(prompt_ids[1], params=params['text_encoder_2'], output_hidden_states=True) + prompt_embeds_2 = prompt_embeds_2_out['hidden_states'][-2] + pooled_embeds = prompt_embeds_2_out['pooler_output'] prompt_embeds = jnp.concatenate([prompt_embeds, prompt_embeds_2], axis=-1) return prompt_embeds, pooled_embeds @@ -142,16 +161,15 @@ def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, d def _generate( self, - prompt_ids: jax.Array, + prompt_ids: jnp.array, params: Union[Dict, FrozenDict], prng_seed: jax.random.KeyArray, num_inference_steps: int, height: int, width: int, - guidance_scale: float = 7.5, - do_classifier_free_guidance: bool = True, - latents: Optional[jax.Array] = None, - neg_prompt_ids: Optional[jax.Array] = None, + guidance_scale: float, + latents: Optional[jnp.array] = None, + neg_prompt_ids: Optional[jnp.array] = None, original_size: tuple = (1024, 1024), crops_coords_top_left: tuple = (0, 0), target_size: tuple = (1024, 1024), @@ -166,10 +184,14 @@ def _generate( batch_size = prompt_embeds.shape[0] if neg_prompt_ids is None: neg_prompt_ids = self.prepare_inputs([""] * batch_size) + # TODO: properly support without classifier guidance here (or drop support entirely) neg_prompt_embeds, pooled_neg_embeds = self.get_embeddings(neg_prompt_ids, params) context = jnp.concatenate([neg_prompt_embeds, prompt_embeds], axis=0) # (2, 77, 2048) + # Ensure model output will be `float32` before going into the scheduler + guidance_scale = jnp.array([guidance_scale], dtype=jnp.float32) + # 3. 
Create random latents latents_shape = ( batch_size, @@ -196,10 +218,36 @@ def _generate( original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype ) + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + # 6. Denoising loop def loop_body(step, args): - # TODO - pass + latents, scheduler_state = args + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + latents_input = jnp.concatenate([latents] * 2) + + t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + timestep = jnp.broadcast_to(t, latents_input.shape[0]) + + latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t) + + # predict the noise residual + noise_pred = self.unet.apply( + {"params": params["unet"]}, + jnp.array(latents_input), + jnp.array(timestep, dtype=jnp.int32), + encoder_hidden_states=context, + added_cond_kwargs=added_cond_kwargs, + ).sample + # perform guidance + noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0) + noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() + return latents, scheduler_state if DEBUG: # run with python for loop @@ -208,5 +256,38 @@ def loop_body(step, args): else: latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state)) - # 7. Deocde latents + # 7. Decode latents # TODO + return latents + +# Static argnums are pipe, num_inference_steps, height, width. A change would trigger recompilation. +# Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`). 
+@partial( + jax.pmap, + in_axes=(None, 0, 0, 0, None, None, None, 0, 0, 0), + static_broadcasted_argnums=(0, 4, 5, 6), +) +def _p_generate( + pipe, + prompt_ids, + params, + prng_seed, + num_inference_steps, + height, + width, + guidance_scale, + latents, + neg_prompt_ids, +): + return pipe._generate( + prompt_ids, + params, + prng_seed, + num_inference_steps, + height, + width, + guidance_scale, + latents, + neg_prompt_ids, + ) + From ff46fa2082de9927ec2ca0a5e5d082ebfe2b831f Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Tue, 25 Jul 2023 11:27:56 +0000 Subject: [PATCH 10/13] Fixing embeddings (untested) --- .../pipeline_flax_stable_diffusion_xl.py | 21 ++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py index 0b7e8ddb514a..a30452295629 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py @@ -185,9 +185,19 @@ def _generate( if neg_prompt_ids is None: neg_prompt_ids = self.prepare_inputs([""] * batch_size) - # TODO: properly support without classifier guidance here (or drop support entirely) - neg_prompt_embeds, pooled_neg_embeds = self.get_embeddings(neg_prompt_ids, params) - context = jnp.concatenate([neg_prompt_embeds, prompt_embeds], axis=0) # (2, 77, 2048) + neg_prompt_embeds, negative_pooled_embeds = self.get_embeddings(neg_prompt_ids, params) + + add_time_ids = self._get_add_time_ids( + original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype + ) + + prompt_embeds = jnp.concatenate([neg_prompt_embeds, prompt_embeds], axis=0) # (2, 77, 2048) + add_text_embeds = jnp.concatenate([negative_pooled_embeds, pooled_embeds], axis=0) + add_time_ids = jnp.concatenate([add_time_ids, add_time_ids], axis=0) + + print(f"prompt_embeds: {prompt_embeds.shape} (2, 77, 2048)") + print(f"add_text_embeds: {add_text_embeds.shape} (2, 1280)") + print(f"add_time_ids: {add_time_ids.shape} (2, 6)") # Ensure model output will be `float32` before going into the scheduler guidance_scale = jnp.array([guidance_scale], dtype=jnp.float32) @@ -214,9 +224,6 @@ def _generate( # 5. 
Prepare added embeddings add_text_embeds = prompt_embeds - add_time_ids = self._get_add_time_ids( - original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype - ) added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} @@ -238,7 +245,7 @@ def loop_body(step, args): {"params": params["unet"]}, jnp.array(latents_input), jnp.array(timestep, dtype=jnp.int32), - encoder_hidden_states=context, + encoder_hidden_states=prompt_embeds, added_cond_kwargs=added_cond_kwargs, ).sample # perform guidance From dfc3c81aebd031d651ceed1d631f4bbe1ce78781 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Tue, 25 Jul 2023 12:55:53 +0000 Subject: [PATCH 11/13] Remove spurious line --- .../stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py index a30452295629..ad71d13fde30 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py @@ -222,9 +222,6 @@ def _generate( params["scheduler"], num_inference_steps=num_inference_steps, shape=latents.shape ) - # 5. Prepare added embeddings - add_text_embeds = prompt_embeds - added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} # 6. Denoising loop From 484e516307da9489d6560d9cce3eb9b1d0afae51 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Tue, 25 Jul 2023 14:31:42 +0000 Subject: [PATCH 12/13] Shard guidance_scale when jitting. --- .../pipeline_flax_stable_diffusion_xl.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py index ad71d13fde30..ffb42742b6d4 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py @@ -102,13 +102,10 @@ def __call__( height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor - if isinstance(guidance_scale, float): - # Convert to a tensor so each device gets a copy. Follow the prompt_ids for - # shape information, as they may be sharded (when `jit` is `True`), or not. + if isinstance(guidance_scale, float) and jit: + # Convert to a tensor so each device gets a copy. 
guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0]) - if len(prompt_ids.shape) > 2: - # Assume sharded - guidance_scale = guidance_scale[:, None] + guidance_scale = guidance_scale[:, None] if jit: images = _p_generate( @@ -195,10 +192,6 @@ def _generate( add_text_embeds = jnp.concatenate([negative_pooled_embeds, pooled_embeds], axis=0) add_time_ids = jnp.concatenate([add_time_ids, add_time_ids], axis=0) - print(f"prompt_embeds: {prompt_embeds.shape} (2, 77, 2048)") - print(f"add_text_embeds: {add_text_embeds.shape} (2, 1280)") - print(f"add_time_ids: {add_time_ids.shape} (2, 6)") - # Ensure model output will be `float32` before going into the scheduler guidance_scale = jnp.array([guidance_scale], dtype=jnp.float32) From 70e1058b253df324fece4e2a12071ae41e18d110 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Tue, 25 Jul 2023 15:07:48 +0000 Subject: [PATCH 13/13] Decode images --- .../pipeline_flax_stable_diffusion_xl.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py index ffb42742b6d4..339159c40669 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py @@ -254,8 +254,12 @@ def loop_body(step, args): latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state)) # 7. Decode latents - # TODO - return latents + # scale and decode the image latents with vae + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample + + image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1) + return image # Static argnums are pipe, num_inference_steps, height, width. A change would trigger recompilation. # Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`).
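
Usage sketch: the series stops short of an end-to-end example, so a minimal driver for the jitted path follows. This is a hedged sketch under stated assumptions, not part of the patches above: the checkpoint id is assumed to ship Flax SDXL weights, and the prompt, dtype, and step count are illustrative.

import jax
import jax.numpy as jnp
from flax.jax_utils import replicate

from diffusers import FlaxStableDiffusionXLPipeline

# Assumed checkpoint id; any SDXL repo that provides Flax weights should work.
pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", dtype=jnp.bfloat16
)

num_devices = jax.device_count()
prompt = ["a photo of an astronaut riding a horse"] * num_devices

# prepare_inputs tokenizes with both tokenizers; one prompt per device gives
# ids of shape (num_devices, 2, 77), which pmap maps over axis 0 so that each
# device sees the (2, 77) layout get_embeddings expects.
prompt_ids = pipeline.prepare_inputs(prompt)
neg_prompt_ids = pipeline.prepare_inputs([""] * num_devices)

params = replicate(params)
rng = jax.random.split(jax.random.PRNGKey(0), num_devices)

# jit=True routes through _p_generate; num_inference_steps, height and width
# are static arguments, so changing them triggers recompilation.
images = pipeline(
    prompt_ids, params, rng, num_inference_steps=25, neg_prompt_ids=neg_prompt_ids, jit=True
).images  # (num_devices, 1, 1024, 1024, 3)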