SDXL flax #4254
Conversation
Fixed prompt embedding shapes so they work in parallel mode. Assuming we always have both text encoders for now, for simplicity.
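For context, a minimal sketch (not part of the PR; the shapes are only illustrative) of how the prompt ids get the leading device axis that parallel (jit=True) mode expects:

```python
import jax
import jax.numpy as jnp
from flax.training.common_utils import shard

num_devices = jax.device_count()

# Dummy token ids standing in for pipeline.prepare_inputs(...) output; the exact
# trailing shape depends on the pipeline (SDXL carries ids for both text encoders).
prompt_ids = jnp.zeros((num_devices, 77), dtype=jnp.int32)

# shard() splits the leading batch axis across the local devices, producing the
# (num_devices, per_device_batch, ...) layout that the pmap'ed pipeline call expects.
prompt_ids = shard(prompt_ids)
print(prompt_ids.shape)  # (num_devices, 1, 77)
```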
The documentation is not available anymore as the PR was closed or merged.
src/diffusers/__init__.py
```python
    FlaxStableDiffusionImg2ImgPipeline,
    FlaxStableDiffusionInpaintPipeline,
    FlaxStableDiffusionPipeline,
    FlaxStableDiffusionXLPipeline
```
TODO: verify dummy classes when transformers and/or flax are not installed.
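For reference, the dummy-object convention used elsewhere in diffusers looks roughly like this (a sketch of the existing pattern, to be verified for the new pipeline):

```python
# utils/dummy_flax_and_transformers_objects.py (pattern sketch)
from ..utils import DummyObject, requires_backends


class FlaxStableDiffusionXLPipeline(metaclass=DummyObject):
    _backends = ["flax", "transformers"]

    def __init__(self, *args, **kwargs):
        # Raise a helpful error if flax and/or transformers are not installed
        requires_backends(self, ["flax", "transformers"])
```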
I found there are 2 reasons why sharding doesn't work (imgs_per_device > 1). The first is that the get_embeddings function returns a tensor for a single image. I think the prompt_ids dims need to be swapped and passed to the text encoder so that it uses the batch size per device. Here is an extremely hacky version I used for my testing with images_per_device=4:

```python
def get_embeddings(self, prompt_ids: jnp.array, params: Union[Dict, FrozenDict]):
    # We assume we have the two encoders
    # [2, 77] -> [2, 1, 77]
    prompt_ids = jnp.expand_dims(prompt_ids, axis=-2)
    prompt_embeds = self.text_encoder(prompt_ids[0][0], params=params['text_encoder'], output_hidden_states=True)
    prompt_embeds = prompt_embeds['hidden_states'][-2]
    prompt_embeds_2_out = self.text_encoder_2(prompt_ids[0][1], params=params['text_encoder_2'], output_hidden_states=True)
    prompt_embeds_2 = prompt_embeds_2_out['hidden_states'][-2]
    pooled_embeds = prompt_embeds_2_out['pooler_output']
    # Concatenate the penultimate hidden states of both encoders along the feature axis
    prompt_embeds = jnp.concatenate([prompt_embeds, prompt_embeds_2], axis=-1)
    # Repeat for the per-device batch size (hardcoded to 4 images per device here)
    prompt_embeds = jnp.squeeze(jnp.stack([prompt_embeds] * 4, axis=1))
    pooled_embeds = jnp.squeeze(jnp.stack([pooled_embeds] * 4, axis=1))
    return prompt_embeds, pooled_embeds
```

The other is that the add_time_ids need to be multiplied by the batch size:

```python
def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):
    # TODO: This is currently not jit'able - probably need to pass add_time_ids as input to __call__
    add_time_ids = list(original_size + crops_coords_top_left + target_size)
    # Repeat for the per-device batch size (hardcoded to 4 images per device here)
    add_time_ids = jnp.array([add_time_ids] * 4, dtype=dtype)
    return add_time_ids
```

After these changes, I was able to get sharding working. I also wrote my own FlaxCLIPTextModuleWithProjection:

```python
import flax.linen as nn
import jax.numpy as jnp
from transformers import CLIPTextConfig
# FlaxCLIPTextTransformer is an internal module of transformers' Flax CLIP implementation
from transformers.models.clip.modeling_flax_clip import FlaxCLIPTextTransformer


class FlaxCLIPTextModuleWithProjection(nn.Module):
    config: CLIPTextConfig
    dtype: jnp.dtype = jnp.bfloat16

    def setup(self):
        self.text_model = FlaxCLIPTextTransformer(self.config, dtype=self.dtype)
        self.text_projection = nn.Dense(self.config.projection_dim, use_bias=False)

    def __call__(
        self,
        input_ids,
        attention_mask,
        position_ids,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Project the pooled output to obtain the text embeddings used for SDXL conditioning
        pooled_output = text_outputs[1]
        text_embeds = self.text_projection(pooled_output)
        outputs = (text_embeds, text_outputs[0]) + text_outputs[2:]
        return outputs
```

Hope this helps!
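A quick sanity check of the module above (not from the PR; the config and dummy inputs are only for illustration) could look like this:

```python
import jax
import jax.numpy as jnp
from transformers import CLIPTextConfig

config = CLIPTextConfig()  # illustrative default config, not the actual SDXL text_encoder_2 config
module = FlaxCLIPTextModuleWithProjection(config, dtype=jnp.float32)

input_ids = jnp.ones((1, 77), dtype=jnp.int32)  # dummy token ids
attention_mask = jnp.ones_like(input_ids)
position_ids = jnp.arange(77)[None, :]

variables = module.init(jax.random.PRNGKey(0), input_ids, attention_mask, position_ids)
text_embeds, last_hidden_state = module.apply(variables, input_ids, attention_mask, position_ids)[:2]
print(text_embeds.shape)  # (1, config.projection_dim)
```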
```python
    transformer_layers_per_block: Union[int, Tuple[int]] = 1
    addition_embed_type: Optional[str] = None
    addition_time_embed_dim: Optional[int] = None
    addition_embed_type_num_heads: int = 64
```
Is this num heads or "head_dim"?
Fixed a couple of bugs that were in the Euler scheduler and aligned it fully with PyTorch. However, I'm still observing the same problem as noted by @entrpn. When casting the scheduler params to bfloat16, the sigma values get "too" incorrect and the final step is a "nan" value. If one leaves the scheduler params in float32, the DiscreteEulerScheduler can work if ~100 inference steps are used. For just 25 inference steps I also get noise most of the time. To me it looks like SDXL + TPU + DiscreteEuler generates some "extreme" values for which bfloat16 is simply not precise enough. Increasing the number of inference steps seems to smooth the Euler derivative so that reasonable images can be generated, as follows:

```python
import time

import jax
import jax.numpy as jnp
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard

from diffusers import FlaxStableDiffusionXLPipeline, FlaxEulerDiscreteScheduler, FlaxDPMSolverMultistepScheduler

num_devices = jax.device_count()
device_type = jax.devices()[0].device_kind
assert "TPU" in device_type, "Available device is not a TPU, please select TPU from Edit > Notebook settings > Hardware accelerator"

dtype = jnp.bfloat16
model_id = "pcuenq/stable-diffusion-xl-base-1.0-flax"
pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(model_id, dtype=dtype)

scheduler_id = "stabilityai/stable-diffusion-xl-base-1.0"
scheduler, state = FlaxEulerDiscreteScheduler.from_pretrained(scheduler_id, subfolder="scheduler")
# scheduler, state = FlaxDPMSolverMultistepScheduler.from_pretrained(scheduler_id, subfolder="scheduler")
pipeline.scheduler = scheduler

# Cast the model weights to bfloat16, but keep the freshly loaded scheduler state in float32
params = jax.tree_util.tree_map(lambda x: x.astype(dtype), params)
params["scheduler"] = state

imgs_per_device = 1
prompt = "photo of a rhino dressed suit and tie sitting at a table in a bar with a bar stools, award winning photography, Elke vogelsang"
prompt = [prompt] * jax.device_count() * imgs_per_device
prompt_ids = pipeline.prepare_inputs(prompt)

neg_prompt = ""
neg_prompt = [neg_prompt] * jax.device_count() * imgs_per_device
neg_prompt_ids = pipeline.prepare_inputs(neg_prompt)

# Replicate the params and shard the inputs across devices
p_params = replicate(params)
prompt_ids = shard(prompt_ids)
neg_prompt_ids = shard(neg_prompt_ids)


def create_key(seed=0):
    return jax.random.PRNGKey(seed)


rng = create_key(30)
rng = jax.random.split(rng, jax.device_count())

do_jit = True


def generate(prompt_ids, neg_prompt_ids):
    images = pipeline(
        prompt_ids if do_jit else prompt_ids[0],
        p_params if do_jit else params,
        rng if do_jit else rng[0],
        num_inference_steps=100,
        neg_prompt_ids=neg_prompt_ids if do_jit else neg_prompt_ids[0],
        guidance_scale=9.0,
        jit=do_jit,
    ).images
    images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
    return pipeline.numpy_to_pil(np.array(images))


start = time.time()
_ = generate(prompt_ids, neg_prompt_ids)
print(f"Compiled in {time.time() - start}")

start = time.time()
images = generate(prompt_ids, neg_prompt_ids)
print(f"Inference in {time.time() - start}")

for i, image in enumerate(images):
    image.save(f"image_{i}.png")
```

=> I actually don't really think that the Euler scheduler is wrong in Flax, I just think the values are too extreme on TPU and we should maybe avoid using it and instead settle on the faster FlaxDPMSolverMultistepScheduler.
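To make the precision point concrete, here is a small standalone check (not from the PR; the value is just an illustrative magnitude close to SDXL's largest Euler sigma):

```python
import jax.numpy as jnp

# bfloat16 only has 7 explicit mantissa bits, so values in [8, 16) are quantized
# to steps of 2**3 * 2**-7 = 0.0625; large sigmas therefore lose precision.
sigma = jnp.array(14.6146, dtype=jnp.float32)  # roughly the largest SDXL Euler sigma
quantized = sigma.astype(jnp.bfloat16).astype(jnp.float32)
print(sigma, quantized, sigma - quantized)  # non-zero error that accumulates over the sampling loop
```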
```python
    dtype=dtype,
    **kwargs,
)
# Load config if we don't provide one
```
That's a refactor I should have done a while ago
```diff
@@ -1,460 +1,473 @@
 from typing import TYPE_CHECKING
```
git diff does a horrible job here as the init is not changed that much. I checked it manually and it should be good
```diff
@@ -0,0 +1,396 @@
 # Copyright 2023 The HuggingFace Team. All rights reserved.
```
@pcuenca does this work already? Otherwise we could also move it out of this PR and do it in a follow up PR as it shouldn't really block anything for the demo
```python
add_time_ids = list(original_size + crops_coords_top_left + target_size)
add_neg_time_ids = list(original_size + crops_coords_top_left + target_size)

# passed_add_embed_dim = (
```
Comments should be removed or treated properly.
Ok to merge for me as soon as:
- Remove the hardcoded values in the SDXL T2I pipeline
- Remove or clean the StableDiffusionXLImg2Img pipeline. Think it doesn't work yet so maybe we can just remove it for now and open a "TODO" issue for it?
src/diffusers/pipelines/__init__.py
```python
_import_structure["stable_diffusion_xl"].extend(
    [
        "FlaxStableDiffusionXLPipeline",
    ]
```
I think this doesn't belong in this section
Merging! Failing test is flaky.

Nice investigation about the Euler scheduler!
* support transformer_layers_per block in flax UNet
* add support for text_time additional embeddings to Flax UNet
* rename attention layers for VAE
* add shape asserts when renaming attention layers
* transpose VAE attention layers
* add pipeline flax SDXL code [WIP]
* continue add pipeline flax SDXL code [WIP]
* cleanup
* Working on JIT support: Fixed prompt embedding shapes so they work in parallel mode. Assuming we always have both text encoders for now, for simplicity.
* Fixing embeddings (untested)
* Remove spurious line
* Shard guidance_scale when jitting.
* Decode images
* Fix sharding
* style
* Refiner UNet can be loaded.
* Refiner / img2img pipeline
* Allow latent outputs from base and latent inputs in refiner: This makes it possible to chain base + refiner without having to use the vae decoder in the base model, the vae encoder in the refiner, skipping conversions to/from PIL, and avoiding TPU <-> CPU memory copies.
* Adapt to FlaxCLIPTextModelOutput
* Update Flax XL pipeline to FlaxCLIPTextModelOutput
* make fix-copies
* make style
* add euler scheduler
* Fix import
* Fix copies, comment unused code.
* Fix SDXL Flax imports
* Fix euler discrete begin
* improve init import
* finish
* put discrete euler in init
* fix flax euler
* Fix more
* make style
* correct init
* correct init
* Temporarily remove FlaxStableDiffusionXLImg2ImgPipeline
* correct pipelines
* finish

---------

Co-authored-by: Martin Müller <[email protected]>
Co-authored-by: patil-suraj <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>


Continuation of #4136, cc @mar-muel
Current status: the pipeline works end-to-end on TPU v4; it takes ~2.7s for 4 images.
- FlaxCLIPTextModelWithProjection: `variant` does not work in Flax, I think.
- Temporarily pushed the weights to https://huggingface.co/pcuenq/stable-diffusion-xl-base-1.0-flax and https://huggingface.co/pcuenq/stable-diffusion-xl-refiner-1.0-flax. They are in float32; I think `variant` is not supported in Flax.