1 change: 1 addition & 0 deletions src/diffusers/__init__.py
@@ -270,6 +270,7 @@
FlaxStableDiffusionImg2ImgPipeline,
FlaxStableDiffusionInpaintPipeline,
FlaxStableDiffusionPipeline,
FlaxStableDiffusionXLPipeline,
)

try:
13 changes: 12 additions & 1 deletion src/diffusers/models/modeling_flax_pytorch_utils.py
@@ -42,9 +42,20 @@ def rename_key(key):
# and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py
def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict):
"""Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary"""

# conv norm or layer norm
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)

# rename attention layers
if len(pt_tuple_key) > 1:
for rename_from, rename_to in (('to_out_0', 'proj_attn'), ('to_k', 'key'), ('to_v', 'value'), ('to_q', 'query')):
if pt_tuple_key[-2] == rename_from:
weight_name = pt_tuple_key[-1]
weight_name = 'kernel' if weight_name == 'weight' else weight_name
renamed_pt_tuple_key = pt_tuple_key[:-2] + (rename_to, weight_name)
if renamed_pt_tuple_key in random_flax_state_dict:
assert random_flax_state_dict[renamed_pt_tuple_key].shape == pt_tensor.T.shape
return renamed_pt_tuple_key, pt_tensor.T

if (
any("norm" in str_ for str_ in pt_tuple_key)
and (pt_tuple_key[-1] == "bias")
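For orientation, here is a minimal NumPy-only sketch of what the new attention-renaming branch does (illustrative only, not part of this diff; the key tuple and the weight shape are invented): to_q/to_k/to_v/to_out_0 map to query/key/value/proj_attn, a trailing "weight" becomes "kernel", and the 2-D PyTorch linear weight is transposed before it is matched against the Flax state dict.

import numpy as np

# Hypothetical PT parameter key and weight for a to_q linear layer.
pt_tuple_key = ("down_blocks_0", "attentions_0", "transformer_blocks_0", "attn1", "to_q", "weight")
pt_tensor = np.zeros((320, 320), dtype=np.float32)  # PT nn.Linear weight: (out_features, in_features)

flax_key = pt_tuple_key[:-2] + ("query", "kernel")   # the key the converter looks up in the Flax state dict
flax_tensor = pt_tensor.T                            # Flax Dense kernels are stored as (in_features, out_features)
print(flax_key, flax_tensor.shape)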
9 changes: 6 additions & 3 deletions src/diffusers/models/unet_2d_blocks_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
only_cross_attention: bool = False
use_memory_efficient_attention: bool = False
dtype: jnp.dtype = jnp.float32
transformer_layers_per_block: int = 1

def setup(self):
resnets = []
@@ -72,7 +73,7 @@ def setup(self):
in_channels=self.out_channels,
n_heads=self.num_attention_heads,
d_head=self.out_channels // self.num_attention_heads,
depth=1,
depth=self.transformer_layers_per_block,
use_linear_projection=self.use_linear_projection,
only_cross_attention=self.only_cross_attention,
use_memory_efficient_attention=self.use_memory_efficient_attention,
@@ -192,6 +193,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
only_cross_attention: bool = False
use_memory_efficient_attention: bool = False
dtype: jnp.dtype = jnp.float32
transformer_layers_per_block: int = 1

def setup(self):
resnets = []
@@ -213,7 +215,7 @@ def setup(self):
in_channels=self.out_channels,
n_heads=self.num_attention_heads,
d_head=self.out_channels // self.num_attention_heads,
depth=1,
depth=self.transformer_layers_per_block,
use_linear_projection=self.use_linear_projection,
only_cross_attention=self.only_cross_attention,
use_memory_efficient_attention=self.use_memory_efficient_attention,
@@ -331,6 +333,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
use_linear_projection: bool = False
use_memory_efficient_attention: bool = False
dtype: jnp.dtype = jnp.float32
transformer_layers_per_block: int = 1

def setup(self):
# there is always at least one resnet
@@ -350,7 +353,7 @@ def setup(self):
in_channels=self.in_channels,
n_heads=self.num_attention_heads,
d_head=self.in_channels // self.num_attention_heads,
depth=1,
depth=self.transformer_layers_per_block,
use_linear_projection=self.use_linear_projection,
use_memory_efficient_attention=self.use_memory_efficient_attention,
dtype=self.dtype,
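These three blocks now thread transformer_layers_per_block down to the depth argument of FlaxTransformer2DModel in place of the previously hard-coded depth=1; with the default of 1 the behavior is unchanged. As an illustrative sketch only (not part of this diff; channel counts, head count, and tensor shapes are invented), a down block with two transformer layers per attention block could be initialized like this:

import jax
import jax.numpy as jnp
from diffusers.models.unet_2d_blocks_flax import FlaxCrossAttnDownBlock2D

block = FlaxCrossAttnDownBlock2D(
    in_channels=320,
    out_channels=640,
    num_attention_heads=10,
    transformer_layers_per_block=2,  # SDXL-style: two transformer layers per attention block
    use_linear_projection=True,
)
variables = block.init(
    jax.random.PRNGKey(0),
    jnp.zeros((1, 32, 32, 320)),   # hidden states, NHWC
    jnp.zeros((1, 1280)),          # time embedding
    jnp.zeros((1, 77, 2048)),      # encoder hidden states (text context)
)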
3 changes: 1 addition & 2 deletions src/diffusers/models/unet_2d_condition.py
@@ -205,7 +205,7 @@ def __init__(
class_embeddings_concat: bool = False,
mid_block_only_cross_attention: Optional[bool] = None,
cross_attention_norm: Optional[str] = None,
addition_embed_type_num_heads=64,
addition_embed_type_num_heads = 64,
):
super().__init__()

@@ -848,7 +848,6 @@ def forward(
time_ids = added_cond_kwargs.get("time_ids")
time_embeds = self.add_time_proj(time_ids.flatten())
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))

add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
add_embeds = add_embeds.to(emb.dtype)
aug_emb = self.add_embedding(add_embeds)
59 changes: 57 additions & 2 deletions src/diffusers/models/unet_2d_condition_flax.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union
from typing import Dict, Optional, Tuple, Union

import flax
import flax.linen as nn
@@ -116,6 +116,11 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
flip_sin_to_cos: bool = True
freq_shift: int = 0
use_memory_efficient_attention: bool = False
transformer_layers_per_block: Union[int, Tuple[int]] = 1
addition_embed_type: Optional[str] = None
addition_time_embed_dim: Optional[int] = None
addition_embed_type_num_heads: int = 64
projection_class_embeddings_input_dim: Optional[int] = None

def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
# init input tensors
@@ -127,7 +132,13 @@ def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}

return self.init(rngs, sample, timesteps, encoder_hidden_states)["params"]
        added_cond_kwargs = None
        if self.addition_embed_type == "text_time":
            added_cond_kwargs = {
                "text_embeds": jnp.zeros((1, 1280), dtype=jnp.float32),  # TODO: This should be set based on config
                "time_ids": jnp.zeros((1, 6), dtype=jnp.float32),
            }
        return self.init(rngs, sample, timesteps, encoder_hidden_states, added_cond_kwargs)["params"]

def setup(self):
block_out_channels = self.block_out_channels
@@ -168,6 +179,22 @@ def setup(self):
if isinstance(num_attention_heads, int):
num_attention_heads = (num_attention_heads,) * len(self.down_block_types)

        # transformer layers per block
        transformer_layers_per_block = self.transformer_layers_per_block
        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * len(self.down_block_types)

        # addition embed types
        if self.addition_embed_type is None:
            self.add_embedding = None
        elif self.addition_embed_type == "text_time":
            if self.addition_time_embed_dim is None:
                raise ValueError(
                    f"addition_embed_type {self.addition_embed_type} requires `addition_time_embed_dim` to not be None"
                )
            self.add_time_proj = FlaxTimesteps(self.addition_time_embed_dim, self.flip_sin_to_cos, self.freq_shift)
            self.add_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype)
        else:
            raise ValueError(f"addition_embed_type: {self.addition_embed_type} must be None or `text_time`.")
# down
down_blocks = []
output_channel = block_out_channels[0]
@@ -182,6 +209,7 @@
out_channels=output_channel,
dropout=self.dropout,
num_layers=self.layers_per_block,
transformer_layers_per_block=transformer_layers_per_block[i],
num_attention_heads=num_attention_heads[i],
add_downsample=not is_final_block,
use_linear_projection=self.use_linear_projection,
@@ -207,6 +235,7 @@
in_channels=block_out_channels[-1],
dropout=self.dropout,
num_attention_heads=num_attention_heads[-1],
transformer_layers_per_block=transformer_layers_per_block[-1],
use_linear_projection=self.use_linear_projection,
use_memory_efficient_attention=self.use_memory_efficient_attention,
dtype=self.dtype,
@@ -218,6 +247,7 @@
reversed_num_attention_heads = list(reversed(num_attention_heads))
only_cross_attention = list(reversed(only_cross_attention))
output_channel = reversed_block_out_channels[0]
reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
for i, up_block_type in enumerate(self.up_block_types):
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
@@ -231,6 +261,7 @@
out_channels=output_channel,
prev_output_channel=prev_output_channel,
num_layers=self.layers_per_block + 1,
transformer_layers_per_block=reversed_transformer_layers_per_block[i],
num_attention_heads=reversed_num_attention_heads[i],
add_upsample=not is_final_block,
dropout=self.dropout,
@@ -269,6 +300,7 @@ def __call__(
sample,
timesteps,
encoder_hidden_states,
added_cond_kwargs: Optional[Union[Dict, FrozenDict]] = None,
down_block_additional_residuals=None,
mid_block_additional_residual=None,
return_dict: bool = True,
@@ -300,6 +332,29 @@
t_emb = self.time_proj(timesteps)
t_emb = self.time_embedding(t_emb)

        # additional embeddings
        aug_emb = None
        if self.addition_embed_type == "text_time":
            if added_cond_kwargs is None:
                raise ValueError(
                    f"Need to provide argument `added_cond_kwargs` for {self.__class__} when using `addition_embed_type={self.addition_embed_type}`"
                )
            text_embeds = added_cond_kwargs.get("text_embeds")
            if text_embeds is None:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
                )
            time_ids = added_cond_kwargs.get("time_ids")
            if time_ids is None:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
                )
            # compute time embeds
            time_embeds = self.add_time_proj(jnp.ravel(time_ids))  # (1, 6) => (6,) => (6, 256)
            time_embeds = jnp.reshape(time_embeds, (text_embeds.shape[0], -1))
            add_embeds = jnp.concatenate([text_embeds, time_embeds], axis=-1)
            aug_emb = self.add_embedding(add_embeds)

        t_emb = t_emb + aug_emb if aug_emb is not None else t_emb

# 2. pre-process
sample = jnp.transpose(sample, (0, 2, 3, 1))
sample = self.conv_in(sample)
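For reference, an illustrative sketch of what a caller would pass in added_cond_kwargs for the text_time case (not part of this diff; the widths assume an SDXL-style configuration with a 1280-dim pooled text embedding and addition_time_embed_dim=256):

import jax.numpy as jnp

batch = 2
added_cond_kwargs = {
    "text_embeds": jnp.zeros((batch, 1280), dtype=jnp.float32),  # pooled text-encoder output
    "time_ids": jnp.zeros((batch, 6), dtype=jnp.float32),        # original (h, w), crop (top, left), target (h, w)
}
# Inside __call__: ravel -> (batch * 6,), sinusoidal projection -> (batch * 6, 256),
# reshape -> (batch, 1536), concat with text_embeds -> (batch, 2816), add_embedding -> time embedding width.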
3 changes: 3 additions & 0 deletions src/diffusers/pipelines/__init__.py
@@ -173,6 +173,9 @@
FlaxStableDiffusionInpaintPipeline,
FlaxStableDiffusionPipeline,
)
from .stable_diffusion_xl import (
FlaxStableDiffusionXLPipeline,
)
try:
if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
raise OptionalDependencyNotAvailable()
13 changes: 12 additions & 1 deletion src/diffusers/pipelines/stable_diffusion_xl/__init__.py
@@ -4,7 +4,13 @@
import numpy as np
import PIL

from ...utils import BaseOutput, is_invisible_watermark_available, is_torch_available, is_transformers_available
from ...utils import (
    BaseOutput,
    is_flax_available,
    is_invisible_watermark_available,
    is_torch_available,
    is_transformers_available,
)


@dataclass
@@ -24,3 +30,8 @@ class StableDiffusionXLPipelineOutput(BaseOutput):
if is_transformers_available() and is_torch_available() and is_invisible_watermark_available():
from .pipeline_stable_diffusion_xl import StableDiffusionXLPipeline
from .pipeline_stable_diffusion_xl_img2img import StableDiffusionXLImg2ImgPipeline


if is_flax_available():
    from .pipeline_flax_stable_diffusion_xl import FlaxStableDiffusionXLPipeline

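Finally, a hypothetical end-to-end usage sketch, assuming the new FlaxStableDiffusionXLPipeline follows the same from_pretrained / prepare_inputs / __call__ pattern as the existing FlaxStableDiffusionPipeline; the pipeline module itself (pipeline_flax_stable_diffusion_xl.py) is not shown in this diff, so every call below is an assumption rather than the final API, and the checkpoint id is only a placeholder (a Flax-format revision of the weights would be needed in practice).

import jax
import jax.numpy as jnp
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from diffusers import FlaxStableDiffusionXLPipeline

# Assumed to mirror the existing Flax pipelines: returns (pipeline, params).
pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", dtype=jnp.bfloat16
)
prompt_ids = pipeline.prepare_inputs(["a photo of an astronaut riding a horse"])

# Replicate params and shard inputs across devices, as with FlaxStableDiffusionPipeline.
params = replicate(params)
prompt_ids = shard(prompt_ids)
rng = jax.random.split(jax.random.PRNGKey(0), jax.device_count())

images = pipeline(prompt_ids, params, rng, num_inference_steps=25, jit=True).images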