Conversation

@pcuenca pcuenca commented Jul 25, 2023

Continuation of #4136, cc @mar-muel

Current status: the pipeline works end-to-end on TPU v4; it takes ~2.7s for 4 images.

  • Sharding bug
  • Refiner UNet
  • Refiner (Img2Img) pipeline
  • Transformers PR to support FlaxCLIPTextModelWithProjection:
  • Push weights to separate repos (variant does not work in Flax, I think)
  • Docs and tests

Temporarily pushed the weights to https://huggingface.co/pcuenq/stable-diffusion-xl-base-1.0-flax and https://huggingface.co/pcuenq/stable-diffusion-xl-refiner-1.0-flax. They are in float32; I don't think variant is supported in Flax.

import jax
import jax.numpy as jnp
import numpy as np

num_devices = jax.device_count()
device_type = jax.devices()[0].device_kind
assert "TPU" in device_type, "Available device is not a TPU, please select TPU from Edit > Notebook settings > Hardware accelerator"

from flax.jax_utils import replicate
from flax.training.common_utils import shard
from diffusers import FlaxStableDiffusionXLPipeline

dtype = jnp.bfloat16
model_id = "pcuenq/stable-diffusion-xl-base-1.0-flax"

pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
    model_id,
    dtype=dtype,
)
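# The temporary repos store float32 weights; cast them to bfloat16 for TPU inference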
params = jax.tree_util.tree_map(lambda x: x.astype(dtype), params)

imgs_per_device = 1

prompt = "a colorful photo of a castle in the middle of a forest with trees and bushes, by Ismail Inceoglu, shadows, high contrast, dynamic shading, hdr, detailed vegetation, digital painting, digital drawing, detailed painting, a detailed digital painting, gothic art, featured on deviantart"
prompt = [prompt] * jax.device_count() * imgs_per_device
prompt_ids = pipeline.prepare_inputs(prompt)

neg_prompt = "fog, grainy, purple"
neg_prompt = [neg_prompt] * jax.device_count() * imgs_per_device
neg_prompt_ids = pipeline.prepare_inputs(neg_prompt)

p_params = replicate(params)
prompt_ids = shard(prompt_ids)
neg_prompt_ids = shard(neg_prompt_ids)

def create_key(seed=0):
    return jax.random.PRNGKey(seed)

rng = create_key(0)
rng = jax.random.split(rng, jax.device_count())

do_jit = True

def generate(prompt_ids, neg_prompt_ids):
    images = pipeline(
        prompt_ids if do_jit else prompt_ids[0],
        p_params if do_jit else params,
        rng if do_jit else rng[0],
        num_inference_steps=20,
        neg_prompt_ids=neg_prompt_ids if do_jit else neg_prompt_ids[0],
        guidance_scale=9.0,
        jit=do_jit,
    ).images
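    # merge the device and per-device batch dimensions before converting to PIL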
    images = images.reshape((images.shape[0] * images.shape[1], ) + images.shape[-3:])
    return pipeline.numpy_to_pil(np.array(images))

import time

start = time.time()
_ = generate(prompt_ids, neg_prompt_ids)
print(f"Compiled in {time.time() - start}")

start = time.time()
images = generate(prompt_ids, neg_prompt_ids)
print(f"Inference in {time.time() - start}")

for i, image in enumerate(images):
    image.save(f"castle_{i}.png")

@pcuenca pcuenca marked this pull request as draft July 25, 2023 07:36
HuggingFaceDocBuilderDev commented Jul 25, 2023

The documentation is not available anymore as the PR was closed or merged.

FlaxStableDiffusionImg2ImgPipeline,
FlaxStableDiffusionInpaintPipeline,
FlaxStableDiffusionPipeline,
FlaxStableDiffusionXLPipeline
Review comment (Member, Author):

TODO: verify dummy classes when transformers and/or flax are not installed.

@pcuenca pcuenca marked this pull request as ready for review July 25, 2023 15:54

entrpn commented Jul 27, 2023

I found two reasons why sharding doesn't work (imgs_per_device > 1).

The first is that the get_embeddings function returns a tensor for a single image only. I think the prompt_ids dims need to be swapped and passed to the text encoders so that they take the per-device batch size. Here's the extremely hacky version I used for my testing with imgs_per_device=4:

    def get_embeddings(self, prompt_ids: jnp.array, params: Union[Dict, FrozenDict]):
        # We assume we have the two encoders
        # [2, 77] -> [2, 1, 77]
        prompt_ids = jnp.expand_dims(prompt_ids, axis=-2)
        prompt_embeds = self.text_encoder(prompt_ids[0][0], params=params['text_encoder'], output_hidden_states=True)
        prompt_embeds = prompt_embeds['hidden_states'][-2]
        prompt_embeds_2_out = self.text_encoder_2(prompt_ids[0][1], params=params['text_encoder_2'], output_hidden_states=True)
        prompt_embeds_2 = prompt_embeds_2_out['hidden_states'][-2]
        pooled_embeds = prompt_embeds_2_out['pooler_output']
        prompt_embeds = jnp.concatenate([prompt_embeds, prompt_embeds_2], axis=-1)
        prompt_embeds = jnp.squeeze(jnp.stack([prompt_embeds] * 4, axis=1))
        pooled_embeds = jnp.squeeze(jnp.stack([pooled_embeds] * 4, axis=1))
        return prompt_embeds, pooled_embeds

The other is that add_time_ids needs to be repeated by the per-device batch size:

    def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):
        add_time_ids = list(original_size + crops_coords_top_left + target_size)  # TODO: This is currently not jit'able - probably need pass add_time_ids as input to __call__
        add_time_ids = jnp.array([add_time_ids]*4, dtype=dtype)
        return add_time_ids

After these changes, I was able to get sharding working.
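
For reference, here is a rough, untested sketch of how both hacks could be made batch-size-agnostic instead of hard-coding 4. It assumes prompt_ids arrives per device with shape [batch_per_device, 2, 77] (one row per tokenizer) and that _get_add_time_ids receives the per-device batch size explicitly; the method names mirror the snippets above, but this is only an assumption about the eventual implementation:

    def get_embeddings(self, prompt_ids, params):
        # prompt_ids: [batch_per_device, 2, 77]; slice out each tokenizer's ids
        prompt_embeds = self.text_encoder(
            prompt_ids[:, 0, :], params=params["text_encoder"], output_hidden_states=True
        )["hidden_states"][-2]
        prompt_embeds_2_out = self.text_encoder_2(
            prompt_ids[:, 1, :], params=params["text_encoder_2"], output_hidden_states=True
        )
        prompt_embeds_2 = prompt_embeds_2_out["hidden_states"][-2]
        pooled_embeds = prompt_embeds_2_out["pooler_output"]
        # concatenate the penultimate hidden states of both encoders along the feature axis
        prompt_embeds = jnp.concatenate([prompt_embeds, prompt_embeds_2], axis=-1)
        return prompt_embeds, pooled_embeds

    def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, batch_size, dtype):
        add_time_ids = list(original_size + crops_coords_top_left + target_size)
        # repeat once per image in the per-device batch instead of hard-coding 4
        return jnp.array([add_time_ids] * batch_size, dtype=dtype)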

I also wrote my own FlaxCLIPTextModuleWithProjection and it seems to be working.

class FlaxCLIPTextModuleWithProjection(nn.Module):
    config: CLIPTextConfig
    dtype: jnp.dtype = jnp.bfloat16

    def setup(self):
        self.text_model = FlaxCLIPTextTransformer(self.config, dtype=self.dtype)
        self.text_projection = nn.Dense(self.config.projection_dim, use_bias=False)

    def __call__(
        self,
        input_ids,
        attention_mask,
        position_ids,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = text_outputs[1]

        text_embeds = self.text_projection(pooled_output)

        outputs = (text_embeds, text_outputs[0]) + text_outputs[2:]

        return outputs
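
A minimal smoke test for a module like this could look as follows (an untested sketch: it assumes the class definition above and its FlaxCLIPTextTransformer import are in scope, and it uses a default CLIPTextConfig purely to exercise the shapes):

import jax
import jax.numpy as jnp
from transformers import CLIPTextConfig

config = CLIPTextConfig()  # default config, only to check shapes
module = FlaxCLIPTextModuleWithProjection(config)

input_ids = jnp.ones((1, 77), dtype=jnp.int32)
attention_mask = jnp.ones_like(input_ids)
position_ids = jnp.arange(77)[None, :]

variables = module.init(jax.random.PRNGKey(0), input_ids, attention_mask, position_ids)
text_embeds, last_hidden_state = module.apply(variables, input_ids, attention_mask, position_ids)[:2]
print(text_embeds.shape)  # (1, config.projection_dim)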

Hope this helps!

transformer_layers_per_block: Union[int, Tuple[int]] = 1
addition_embed_type: Optional[str] = None
addition_time_embed_dim: Optional[int] = None
addition_embed_type_num_heads: int = 64
Review comment (Contributor):

Is this num heads or "head_dim"?

patrickvonplaten commented Sep 22, 2023

@pcuenca @patil-suraj I think there's something wrong with the Euler scheduler. If I run it with bf16 conversion, I get undefined latents for the last step. This seems to come from the alphas_cumprod values getting rounded to 1 during the bf16 conversion, which makes sigma = 0; the step function then divides by sigma for the last step, which produces undefined latents.
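
The precision issue can be reproduced in isolation with a representative value (a minimal sketch, not the actual scheduler state):

import jax.numpy as jnp

# alphas_cumprod is very close to 1 at the low-noise end of the schedule, and
# bfloat16 rounds values above 1 - 2**-9 up to exactly 1.0, so sigma collapses to 0.
alpha = jnp.array(0.9992, dtype=jnp.float32)
sigma_f32 = jnp.sqrt((1 - alpha) / alpha)
sigma_bf16 = jnp.sqrt((1 - alpha.astype(jnp.bfloat16)) / alpha.astype(jnp.bfloat16))

print(sigma_f32)   # ~0.028, small but non-zero
print(sigma_bf16)  # 0.0 -> the step function later divides by sigma -> nan latents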

If I don't set the scheduler to bf16, then I get random artifacts in the image. See below.

[image: castle_4 — sample showing random artifacts]

I changed the scheduler in model_index.json from FlaxEulerDiscreteScheduler to FlaxDPMSolverMultistepScheduler and generation works even with bf16 conversion of the scheduler values.

Fixed a couple of bugs that were in the Euler scheduler and aligned it fully with PyTorch.
Also tested that the Euler scheduler gives exactly the same results for a dummy model when comparing Flax and PyTorch on CPU.

However, I'm still observing the same problem as noted by @entrpn. When casting the scheduler params to bfloat16, the sigma values become too inaccurate and the final step produces "nan" values. If the scheduler params are left in float32, the Euler discrete scheduler can work if ~100 inference steps are used; with just 25 inference steps I still get noise most of the time.

To me it looks like SDXL + TPU + Euler discrete generates some "extreme" values for which bfloat16 is simply not precise enough. Increasing the number of inference steps seems to smooth the Euler derivative so that reasonable images can be generated as follows:

import jax
import jax.numpy as jnp
import numpy as np

num_devices = jax.device_count()
device_type = jax.devices()[0].device_kind
assert "TPU" in device_type, "Available device is not a TPU, please select TPU from Edit > Notebook settings > Hardware accelerator"

from flax.jax_utils import replicate
from flax.training.common_utils import shard
from diffusers import FlaxStableDiffusionXLPipeline, FlaxEulerDiscreteScheduler, FlaxDPMSolverMultistepScheduler

dtype = jnp.bfloat16
model_id = "pcuenq/stable-diffusion-xl-base-1.0-flax"

pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(model_id, dtype=dtype)

scheduler_id = "stabilityai/stable-diffusion-xl-base-1.0"
scheduler, state = FlaxEulerDiscreteScheduler.from_pretrained(scheduler_id, subfolder="scheduler")
# scheduler, state = FlaxDPMSolverMultistepScheduler.from_pretrained(scheduler_id, subfolder="scheduler")
pipeline.scheduler = scheduler

params = jax.tree_util.tree_map(lambda x: x.astype(dtype), params)
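# Assign the scheduler state after the bfloat16 cast so it stays in float32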
params["scheduler"] = state

imgs_per_device = 1

prompt = "photo of a rhino dressed suit and tie sitting at a table in a bar with a bar stools, award winning photography, Elke vogelsang"
prompt = [prompt] * jax.device_count() * imgs_per_device
prompt_ids = pipeline.prepare_inputs(prompt)

neg_prompt = ""
neg_prompt = [neg_prompt] * jax.device_count() * imgs_per_device
neg_prompt_ids = pipeline.prepare_inputs(neg_prompt)

p_params = replicate(params)
prompt_ids = shard(prompt_ids)
neg_prompt_ids = shard(neg_prompt_ids)

def create_key(seed=0):
    return jax.random.PRNGKey(seed)

rng = create_key(30)
rng = jax.random.split(rng, jax.device_count())

do_jit = True

def generate(prompt_ids, neg_prompt_ids):
    images = pipeline(
        prompt_ids if do_jit else prompt_ids[0],
        p_params if do_jit else params,
        rng if do_jit else rng[0],
        num_inference_steps=100,
        neg_prompt_ids=neg_prompt_ids if do_jit else neg_prompt_ids[0],
        guidance_scale=9.0,
        jit=do_jit,
    ).images
    images = images.reshape((images.shape[0] * images.shape[1], ) + images.shape[-3:])
    return pipeline.numpy_to_pil(np.array(images))

import time

start = time.time()
_ = generate(prompt_ids, neg_prompt_ids)
print(f"Compiled in {time.time() - start}")

start = time.time()
images = generate(prompt_ids, neg_prompt_ids)
print(f"Inference in {time.time() - start}")

for i, image in enumerate(images):
    image.save(f"image_{i}.png")

[image: generated sample from the snippet above]

=> I don't actually think the Euler scheduler is wrong in Flax; I just think the values are too extreme on TPU, and we should maybe avoid using it and instead settle on the faster FlaxDPMSolverMultistepScheduler.
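
For reference, a minimal sketch of that swap (untested; it reuses pipeline and params from the snippet above, and 25 steps is just an illustrative number):

from diffusers import FlaxDPMSolverMultistepScheduler

scheduler, scheduler_state = FlaxDPMSolverMultistepScheduler.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="scheduler"
)
pipeline.scheduler = scheduler
params["scheduler"] = scheduler_state  # keep the scheduler state in float32
# then replicate/shard as before and generate with e.g. num_inference_steps=25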

Wdyt @pcuenca @entrpn

dtype=dtype,
**kwargs,
)
# Load config if we don't provide one
Review comment (Contributor):

That's a refactor I should have done a while ago

@@ -1,460 +1,473 @@
from typing import TYPE_CHECKING
Review comment (Contributor):

git diff does a horrible job here, as the init is not changed that much. I checked it manually and it should be good.

@@ -0,0 +1,396 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
Review comment (Contributor):

@pcuenca does this work already? Otherwise we could also move it out of this PR and do it in a follow-up PR, as it shouldn't really block anything for the demo.

add_time_ids = list(original_size + crops_coords_top_left + target_size)
add_neg_time_ids = list(original_size + crops_coords_top_left + target_size)

# passed_add_embed_dim = (
Review comment (Contributor):

These commented-out lines should probably be removed or addressed.

@patrickvonplaten patrickvonplaten left a comment (Contributor):

Ok to merge for me as soon as:

  • Remove the hardcoded values in the SDXL T2I pipeline
  • Remove or clean up the StableDiffusionXLImg2Img pipeline. I think it doesn't work yet, so maybe we can just remove it for now and open a "TODO" issue for it?

Comment on lines 205 to 208
_import_structure["stable_diffusion_xl"].extend(
[
"FlaxStableDiffusionXLPipeline",
]
Review comment (Member, Author):

I think this doesn't belong in this section

@patrickvonplaten

Merging! Failing test is flaky

@patrickvonplaten patrickvonplaten merged commit 3651b14 into main Sep 22, 2023

pcuenca commented Sep 22, 2023

Nice investigation about the Euler scheduler!

@patrickvonplaten patrickvonplaten deleted the sdxl-flax branch September 25, 2023 11:53
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
* support transformer_layers_per block in flax UNet

* add support for text_time additional embeddings to Flax UNet

* rename attention layers for VAE

* add shape asserts when renaming attention layers

* transpose VAE attention layers

* add pipeline flax SDXL code [WIP]

* continue add pipeline flax SDXL code [WIP]

* cleanup

* Working on JIT support

Fixed prompt embedding shapes so they work in parallel mode. Assuming we
always have both text encoders for now, for simplicity.

* Fixing embeddings (untested)

* Remove spurious line

* Shard guidance_scale when jitting.

* Decode images

* Fix sharding

* style

* Refiner UNet can be loaded.

* Refiner / img2img pipeline

* Allow latent outputs from base and latent inputs in refiner

This makes it possible to chain base + refiner without having to use the
vae decoder in the base model, the vae encoder in the refiner, skipping
conversions to/from PIL, and avoiding TPU <-> CPU memory copies.

* Adapt to FlaxCLIPTextModelOutput

* Update Flax XL pipeline to FlaxCLIPTextModelOutput

* make fix-copies

* make style

* add euler scheduler

* Fix import

* Fix copies, comment unused code.

* Fix SDXL Flax imports

* Fix euler discrete begin

* improve init import

* finish

* put discrete euler in init

* fix flax euler

* Fix more

* make style

* correct init

* correct init

* Temporarily remove FlaxStableDiffusionXLImg2ImgPipeline

* correct pipelines

* finish

---------

Co-authored-by: Martin Müller <[email protected]>
Co-authored-by: patil-suraj <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024