Merged
Changes from all commits (45 commits)
cbdd6d6
support transformer_layers_per_block in flax UNet
mar-muel Jul 11, 2023
4c78659
add support for text_time additional embeddings to Flax UNet
mar-muel Jul 12, 2023
3aa3164
rename attention layers for VAE
mar-muel Jul 12, 2023
bc267e4
add shape asserts when renaming attention layers
mar-muel Jul 13, 2023
ea0e675
transpose VAE attention layers
mar-muel Jul 14, 2023
1cc6c37
add pipeline flax SDXL code [WIP]
mar-muel Jul 17, 2023
9771feb
continue add pipeline flax SDXL code [WIP]
mar-muel Jul 18, 2023
1137263
cleanup
mar-muel Jul 18, 2023
f02d795
Working on JIT support
pcuenca Jul 25, 2023
e6eaee0
Merge remote-tracking branch 'origin/main' into sdxl-flax
pcuenca Jul 25, 2023
ff46fa2
Fixing embeddings (untested)
pcuenca Jul 25, 2023
dfc3c81
Remove spurious line
pcuenca Jul 25, 2023
484e516
Shard guidance_scale when jitting.
pcuenca Jul 25, 2023
70e1058
Decode images
pcuenca Jul 25, 2023
2620ff3
Merge remote-tracking branch 'origin/main' into sdxl-flax
pcuenca Jul 25, 2023
8e63802
Fix sharding
pcuenca Jul 28, 2023
4fd885a
style
pcuenca Jul 28, 2023
27d26ca
Refiner UNet can be loaded.
pcuenca Jul 30, 2023
61d93f4
Refiner / img2img pipeline
pcuenca Aug 1, 2023
636d69f
Allow latent outputs from base and latent inputs in refiner
pcuenca Aug 4, 2023
3a00bfe
Adapt to FlaxCLIPTextModelOutput
pcuenca Aug 5, 2023
a6442f7
Merge branch 'sdxl-flax' of github.com:huggingface/diffusers into sdx…
pcuenca Aug 5, 2023
cba4327
Update Flax XL pipeline to FlaxCLIPTextModelOutput
pcuenca Aug 5, 2023
a34ca77
Merge remote-tracking branch 'origin/main' into sdxl-flax
pcuenca Aug 8, 2023
4411bef
make fix-copies
pcuenca Aug 8, 2023
5dc45d8
make style
pcuenca Aug 8, 2023
dd64ffa
add euler scheduler
patil-suraj Aug 16, 2023
3a3850f
Merge remote-tracking branch 'origin' into sdxl-flax
pcuenca Sep 19, 2023
66fd712
Fix import
pcuenca Sep 19, 2023
4ecbc3f
Fix copies, comment unused code.
pcuenca Sep 19, 2023
0a5be23
Fix SDXL Flax imports
pcuenca Sep 21, 2023
2bcf68d
Fix euler discrete begin
patrickvonplaten Sep 21, 2023
a9f3404
improve init import
patrickvonplaten Sep 22, 2023
cba9f04
finish
patrickvonplaten Sep 22, 2023
2c8e77f
put discrete euler in init
patrickvonplaten Sep 22, 2023
461be99
fix flax euler
patrickvonplaten Sep 22, 2023
a975fc0
Merge branch 'sdxl-flax' of https://github.com/huggingface/diffusers …
patrickvonplaten Sep 22, 2023
d02a50b
Fix more
patrickvonplaten Sep 22, 2023
53d595a
fix
patrickvonplaten Sep 22, 2023
3e1ffba
make style
patrickvonplaten Sep 22, 2023
e01da93
correct init
patrickvonplaten Sep 22, 2023
4346ed3
correct init
patrickvonplaten Sep 22, 2023
da22045
Temporarily remove FlaxStableDiffusionXLImg2ImgPipeline
pcuenca Sep 22, 2023
c503a36
correct pipelines
patrickvonplaten Sep 22, 2023
b52c56d
finish
patrickvonplaten Sep 22, 2023
4 changes: 4 additions & 0 deletions src/diffusers/__init__.py
@@ -368,6 +368,7 @@
"FlaxDDIMScheduler",
"FlaxDDPMScheduler",
"FlaxDPMSolverMultistepScheduler",
"FlaxEulerDiscreteScheduler",
"FlaxKarrasVeScheduler",
"FlaxLMSDiscreteScheduler",
"FlaxPNDMScheduler",
@@ -395,6 +396,7 @@
"FlaxStableDiffusionImg2ImgPipeline",
"FlaxStableDiffusionInpaintPipeline",
"FlaxStableDiffusionPipeline",
"FlaxStableDiffusionXLPipeline",
]
)

@@ -673,6 +675,7 @@
FlaxDDIMScheduler,
FlaxDDPMScheduler,
FlaxDPMSolverMultistepScheduler,
FlaxEulerDiscreteScheduler,
FlaxKarrasVeScheduler,
FlaxLMSDiscreteScheduler,
FlaxPNDMScheduler,
@@ -691,6 +694,7 @@
FlaxStableDiffusionImg2ImgPipeline,
FlaxStableDiffusionInpaintPipeline,
FlaxStableDiffusionPipeline,
FlaxStableDiffusionXLPipeline,
)

try:
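Note: the `__init__.py` additions expose `FlaxEulerDiscreteScheduler` and `FlaxStableDiffusionXLPipeline` at the package top level. A minimal usage sketch, assuming the new pipeline mirrors the existing `FlaxStableDiffusionPipeline` API; the checkpoint id, `prepare_inputs`, and the replicate/shard pattern below are assumptions, not taken from this diff:

```python
import jax
import jax.numpy as jnp
from flax.jax_utils import replicate
from flax.training.common_utils import shard

from diffusers import FlaxStableDiffusionXLPipeline

# Assumed checkpoint id; bfloat16 is a common choice on TPU.
pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", dtype=jnp.bfloat16
)

prompts = ["a photo of an astronaut riding a horse"] * jax.device_count()
prompt_ids = pipeline.prepare_inputs(prompts)

# Replicate weights and shard inputs across devices, as with the other Flax pipelines.
params = replicate(params)
prompt_ids = shard(prompt_ids)
rng = jax.random.split(jax.random.PRNGKey(0), jax.device_count())

images = pipeline(prompt_ids, params, rng, num_inference_steps=25, jit=True).images
```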
18 changes: 17 additions & 1 deletion src/diffusers/models/modeling_flax_pytorch_utils.py
@@ -42,9 +42,25 @@ def rename_key(key):
# and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py
def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict):
"""Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary"""

# conv norm or layer norm
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)

# rename attention layers
if len(pt_tuple_key) > 1:
for rename_from, rename_to in (
("to_out_0", "proj_attn"),
("to_k", "key"),
("to_v", "value"),
("to_q", "query"),
):
if pt_tuple_key[-2] == rename_from:
weight_name = pt_tuple_key[-1]
weight_name = "kernel" if weight_name == "weight" else weight_name
renamed_pt_tuple_key = pt_tuple_key[:-2] + (rename_to, weight_name)
if renamed_pt_tuple_key in random_flax_state_dict:
assert random_flax_state_dict[renamed_pt_tuple_key].shape == pt_tensor.T.shape
return renamed_pt_tuple_key, pt_tensor.T

if (
any("norm" in str_ for str_ in pt_tuple_key)
and (pt_tuple_key[-1] == "bias")
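Note on the `pt_tensor.T` above: PyTorch `nn.Linear` stores its weight as `(out_features, in_features)` while a Flax `nn.Dense` kernel is `(in_features, out_features)`, so the renamed attention projections must be transposed. A standalone illustration (not part of the PR):

```python
import numpy as np

# A PyTorch nn.Linear(640, 320) weight has shape (out_features, in_features) = (320, 640).
pt_weight = np.random.randn(320, 640)

# The corresponding Flax nn.Dense kernel expects (in_features, out_features) = (640, 320),
# which is exactly what the shape assert checks before returning pt_tensor.T.
flax_kernel = pt_weight.T
assert flax_kernel.shape == (640, 320)
```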
34 changes: 17 additions & 17 deletions src/diffusers/models/modeling_flax_utils.py
@@ -303,23 +303,23 @@ def from_pretrained(
"framework": "flax",
}

# Load config if we don't provide a configuration
config_path = config if config is not None else pretrained_model_name_or_path
model, model_kwargs = cls.from_config(
config_path,
cache_dir=cache_dir,
return_unused_kwargs=True,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
# model args
dtype=dtype,
**kwargs,
)
# Load config if we don't provide one
Review comment (Contributor): That's a refactor I should have done a while ago

if config is None:
config, unused_kwargs = cls.load_config(
pretrained_model_name_or_path,
cache_dir=cache_dir,
return_unused_kwargs=True,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
**kwargs,
)

model, model_kwargs = cls.from_config(config, dtype=dtype, return_unused_kwargs=True, **unused_kwargs)

# Load model
pretrained_path_with_subfolder = (
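The refactor separates config loading from model construction. A hedged sketch of the same two-step pattern from user code, using the `load_config`/`from_config` methods the diff relies on (the checkpoint id and `subfolder` are assumptions):

```python
import jax.numpy as jnp

from diffusers import FlaxUNet2DConditionModel

# Step 1: load just the config (plus any unused kwargs) from the hub or a local path.
config, unused_kwargs = FlaxUNet2DConditionModel.load_config(
    "stabilityai/stable-diffusion-xl-base-1.0",  # assumed checkpoint id
    subfolder="unet",
    return_unused_kwargs=True,
)

# Step 2: build the (randomly initialized) module from that config, mirroring the
# cls.from_config(config, dtype=dtype, return_unused_kwargs=True, **unused_kwargs) call above.
model, model_kwargs = FlaxUNet2DConditionModel.from_config(
    config, dtype=jnp.bfloat16, return_unused_kwargs=True, **unused_kwargs
)
```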
9 changes: 6 additions & 3 deletions src/diffusers/models/unet_2d_blocks_flax.py
@@ -52,6 +52,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
only_cross_attention: bool = False
use_memory_efficient_attention: bool = False
dtype: jnp.dtype = jnp.float32
transformer_layers_per_block: int = 1

def setup(self):
resnets = []
@@ -72,7 +73,7 @@ def setup(self):
in_channels=self.out_channels,
n_heads=self.num_attention_heads,
d_head=self.out_channels // self.num_attention_heads,
depth=1,
depth=self.transformer_layers_per_block,
use_linear_projection=self.use_linear_projection,
only_cross_attention=self.only_cross_attention,
use_memory_efficient_attention=self.use_memory_efficient_attention,
@@ -192,6 +193,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
only_cross_attention: bool = False
use_memory_efficient_attention: bool = False
dtype: jnp.dtype = jnp.float32
transformer_layers_per_block: int = 1

def setup(self):
resnets = []
@@ -213,7 +215,7 @@ def setup(self):
in_channels=self.out_channels,
n_heads=self.num_attention_heads,
d_head=self.out_channels // self.num_attention_heads,
depth=1,
depth=self.transformer_layers_per_block,
use_linear_projection=self.use_linear_projection,
only_cross_attention=self.only_cross_attention,
use_memory_efficient_attention=self.use_memory_efficient_attention,
@@ -331,6 +333,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
use_linear_projection: bool = False
use_memory_efficient_attention: bool = False
dtype: jnp.dtype = jnp.float32
transformer_layers_per_block: int = 1

def setup(self):
# there is always at least one resnet
@@ -350,7 +353,7 @@ def setup(self):
in_channels=self.in_channels,
n_heads=self.num_attention_heads,
d_head=self.in_channels // self.num_attention_heads,
depth=1,
depth=self.transformer_layers_per_block,
use_linear_projection=self.use_linear_projection,
use_memory_efficient_attention=self.use_memory_efficient_attention,
dtype=self.dtype,
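With this change, the `depth` of each `FlaxTransformer2DModel` follows the new `transformer_layers_per_block` field instead of being hard-coded to 1, which SDXL needs (the released SDXL base config uses deeper transformer stacks per block, e.g. 2 and 10). A small sketch of constructing a block with the new field; the channel/head values here are illustrative, not taken from this diff:

```python
import jax.numpy as jnp

from diffusers.models.unet_2d_blocks_flax import FlaxCrossAttnDownBlock2D

# A down block whose attention blocks each stack two transformer layers.
block = FlaxCrossAttnDownBlock2D(
    in_channels=320,
    out_channels=640,
    num_attention_heads=10,
    transformer_layers_per_block=2,  # new field added in this PR
    dtype=jnp.float32,
)
```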
1 change: 0 additions & 1 deletion src/diffusers/models/unet_2d_condition.py
@@ -883,7 +883,6 @@ def forward(
time_ids = added_cond_kwargs.get("time_ids")
time_embeds = self.add_time_proj(time_ids.flatten())
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))

add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
add_embeds = add_embeds.to(emb.dtype)
aug_emb = self.add_embedding(add_embeds)
67 changes: 65 additions & 2 deletions src/diffusers/models/unet_2d_condition_flax.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union
from typing import Dict, Optional, Tuple, Union

import flax
import flax.linen as nn
@@ -116,6 +116,11 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
flip_sin_to_cos: bool = True
freq_shift: int = 0
use_memory_efficient_attention: bool = False
transformer_layers_per_block: Union[int, Tuple[int]] = 1
addition_embed_type: Optional[str] = None
addition_time_embed_dim: Optional[int] = None
addition_embed_type_num_heads: int = 64
Review comment (Contributor): Is this num heads or "head_dim"?

projection_class_embeddings_input_dim: Optional[int] = None

def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
# init input tensors
@@ -127,7 +132,17 @@ def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}

return self.init(rngs, sample, timesteps, encoder_hidden_states)["params"]
added_cond_kwargs = None
if self.addition_embed_type == "text_time":
# TODO: how to get this from the config? It's no longer cross_attention_dim
text_embeds_dim = 1280
Review comment (patrickvonplaten, Contributor, Aug 2, 2023): so it can be retrieved from text_encoder_2.config.projection_dim
Review comment (Member Author): You are right, thanks!
Review comment (Contributor): @pcuenca can we correct this?

time_ids_channels = self.projection_class_embeddings_input_dim - text_embeds_dim
time_ids_dims = time_ids_channels // self.addition_time_embed_dim
added_cond_kwargs = {
"text_embeds": jnp.zeros((1, text_embeds_dim), dtype=jnp.float32),
"time_ids": jnp.zeros((1, time_ids_dims), dtype=jnp.float32),
}
Review comment on lines +138 to +144 (Member Author): @mar-muel this is how I'm computing the time_ids dims. text_embeds_dim is still hardcoded; not sure where to get that from.

return self.init(rngs, sample, timesteps, encoder_hidden_states, added_cond_kwargs)["params"]

def setup(self):
block_out_channels = self.block_out_channels
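A worked instance of the `time_ids_dims` arithmetic in `init_weights` above, plugging in SDXL-base-like numbers (2816, 1280 and 256 are assumptions based on the released SDXL configs, not read from this diff):

```python
# Width expected by add_embedding's first linear layer.
projection_class_embeddings_input_dim = 2816
# Pooled text embedding width, i.e. text_encoder_2.config.projection_dim.
text_embeds_dim = 1280
# Sinusoidal embedding width produced per time_ids scalar.
addition_time_embed_dim = 256

time_ids_channels = projection_class_embeddings_input_dim - text_embeds_dim  # 1536
time_ids_dims = time_ids_channels // addition_time_embed_dim                 # 6
assert time_ids_dims == 6  # six micro-conditioning scalars: original size, crop coords, target size
```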
@@ -168,6 +183,24 @@ def setup(self):
if isinstance(num_attention_heads, int):
num_attention_heads = (num_attention_heads,) * len(self.down_block_types)

# transformer layers per block
transformer_layers_per_block = self.transformer_layers_per_block
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * len(self.down_block_types)

# addition embed types
if self.addition_embed_type is None:
self.add_embedding = None
elif self.addition_embed_type == "text_time":
if self.addition_time_embed_dim is None:
raise ValueError(
f"addition_embed_type {self.addition_embed_type} requires `addition_time_embed_dim` to not be None"
)
self.add_time_proj = FlaxTimesteps(self.addition_time_embed_dim, self.flip_sin_to_cos, self.freq_shift)
self.add_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype)
else:
raise ValueError(f"addition_embed_type: {self.addition_embed_type} must be None or `text_time`.")

# down
down_blocks = []
output_channel = block_out_channels[0]
@@ -182,6 +215,7 @@ def setup(self):
out_channels=output_channel,
dropout=self.dropout,
num_layers=self.layers_per_block,
transformer_layers_per_block=transformer_layers_per_block[i],
num_attention_heads=num_attention_heads[i],
add_downsample=not is_final_block,
use_linear_projection=self.use_linear_projection,
@@ -207,6 +241,7 @@ def setup(self):
in_channels=block_out_channels[-1],
dropout=self.dropout,
num_attention_heads=num_attention_heads[-1],
transformer_layers_per_block=transformer_layers_per_block[-1],
use_linear_projection=self.use_linear_projection,
use_memory_efficient_attention=self.use_memory_efficient_attention,
dtype=self.dtype,
@@ -218,6 +253,7 @@ def setup(self):
reversed_num_attention_heads = list(reversed(num_attention_heads))
only_cross_attention = list(reversed(only_cross_attention))
output_channel = reversed_block_out_channels[0]
reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
for i, up_block_type in enumerate(self.up_block_types):
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
@@ -231,6 +267,7 @@ def setup(self):
out_channels=output_channel,
prev_output_channel=prev_output_channel,
num_layers=self.layers_per_block + 1,
transformer_layers_per_block=reversed_transformer_layers_per_block[i],
num_attention_heads=reversed_num_attention_heads[i],
add_upsample=not is_final_block,
dropout=self.dropout,
@@ -269,6 +306,7 @@ def __call__(
sample,
timesteps,
encoder_hidden_states,
added_cond_kwargs: Optional[Union[Dict, FrozenDict]] = None,
down_block_additional_residuals=None,
mid_block_additional_residual=None,
return_dict: bool = True,
@@ -300,6 +338,31 @@ def __call__(
t_emb = self.time_proj(timesteps)
t_emb = self.time_embedding(t_emb)

# additional embeddings
aug_emb = None
if self.addition_embed_type == "text_time":
if added_cond_kwargs is None:
raise ValueError(
f"Need to provide argument `added_cond_kwargs` for {self.__class__} when using `addition_embed_type={self.addition_embed_type}`"
)
text_embeds = added_cond_kwargs.get("text_embeds")
if text_embeds is None:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
)
time_ids = added_cond_kwargs.get("time_ids")
if time_ids is None:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
)
# compute time embeds
time_embeds = self.add_time_proj(jnp.ravel(time_ids)) # (1, 6) => (6,) => (6, 256)
time_embeds = jnp.reshape(time_embeds, (text_embeds.shape[0], -1))
add_embeds = jnp.concatenate([text_embeds, time_embeds], axis=-1)
aug_emb = self.add_embedding(add_embeds)

t_emb = t_emb + aug_emb if aug_emb is not None else t_emb

# 2. pre-process
sample = jnp.transpose(sample, (0, 2, 3, 1))
sample = self.conv_in(sample)
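A shape trace of the `text_time` branch in `__call__` above, with SDXL-base-like sizes (batch of 1, 1280-wide pooled text embeds, 256-dim per-scalar time embeddings — all assumptions, not read from this diff):

```python
import jax.numpy as jnp

text_embeds = jnp.zeros((1, 1280))  # pooled prompt embeddings
time_ids = jnp.zeros((1, 6))        # original_size + crop_coords + target_size

# self.add_time_proj maps each of the 6 scalars to a 256-dim sinusoidal embedding,
# so jnp.ravel(time_ids) of shape (6,) becomes (6, 256); stand in for it with zeros here.
time_embeds = jnp.zeros((jnp.ravel(time_ids).shape[0], 256))        # (6, 256)
time_embeds = jnp.reshape(time_embeds, (text_embeds.shape[0], -1))  # (1, 1536)
add_embeds = jnp.concatenate([text_embeds, time_embeds], axis=-1)   # (1, 2816)
# self.add_embedding then projects add_embeds to the time-embedding width, and the
# result (aug_emb) is added to t_emb before the UNet blocks run.
```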