SDXL flax #4254
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Optional, Tuple, Union
+from typing import Dict, Optional, Tuple, Union

 import flax
 import flax.linen as nn

@@ -116,6 +116,11 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
     flip_sin_to_cos: bool = True
     freq_shift: int = 0
     use_memory_efficient_attention: bool = False
+    transformer_layers_per_block: Union[int, Tuple[int]] = 1
+    addition_embed_type: Optional[str] = None
+    addition_time_embed_dim: Optional[int] = None
+    addition_embed_type_num_heads: int = 64
Contributor: Is this num heads or …
+    projection_class_embeddings_input_dim: Optional[int] = None

     def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
         # init input tensors

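To make the new config fields concrete, here is a minimal sketch of instantiating an SDXL-sized Flax UNet. The specific values are assumptions consistent with the numbers quoted elsewhere in this diff (1280-dim pooled text embeds, six 256-dim time embeddings), not something the diff itself pins down:

import jax.numpy as jnp
from diffusers import FlaxUNet2DConditionModel

# Hypothetical SDXL-like configuration (values are assumptions, not taken from this PR)
unet = FlaxUNet2DConditionModel(
    sample_size=128,
    in_channels=4,
    out_channels=4,
    block_out_channels=(320, 640, 1280),
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=2048,
    transformer_layers_per_block=(1, 2, 10),     # new: per-block transformer depth
    addition_embed_type="text_time",             # new: extra text/time conditioning
    addition_time_embed_dim=256,                 # new: dim of each time-id embedding
    projection_class_embeddings_input_dim=2816,  # new: 1280 text embeds + 6 * 256 time embeds
    dtype=jnp.float32,
)
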
@@ -127,7 +132,17 @@ def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
         params_rng, dropout_rng = jax.random.split(rng)
         rngs = {"params": params_rng, "dropout": dropout_rng}

-        return self.init(rngs, sample, timesteps, encoder_hidden_states)["params"]
+        added_cond_kwargs = None
+        if self.addition_embed_type == "text_time":
+            # TODO: how to get this from the config? It's no longer cross_attention_dim
+            text_embeds_dim = 1280
Contributor: so it can be retrieved from …
Member (author): You are right, thanks!
Contributor: @pcuenca can we correct this?
+            time_ids_channels = self.projection_class_embeddings_input_dim - text_embeds_dim
+            time_ids_dims = time_ids_channels // self.addition_time_embed_dim
+            added_cond_kwargs = {
+                "text_embeds": jnp.zeros((1, text_embeds_dim), dtype=jnp.float32),
+                "time_ids": jnp.zeros((1, time_ids_dims), dtype=jnp.float32),
+            }
Member (author), commenting on lines +138 to +144: @mar-muel this is how I'm computing the …
+        return self.init(rngs, sample, timesteps, encoder_hidden_states, added_cond_kwargs)["params"]

     def setup(self):
         block_out_channels = self.block_out_channels

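As a quick sanity check of the shape arithmetic above, using the values implied by the comments in this diff (1280-dim pooled text embeds, a (1, 6) time_ids tensor, 256-dim time embeddings; the 2816 total is an assumption consistent with those numbers):

# Hypothetical SDXL-like values (assumed); 2816 = 1280 + 6 * 256
projection_class_embeddings_input_dim = 2816
text_embeds_dim = 1280
addition_time_embed_dim = 256

time_ids_channels = projection_class_embeddings_input_dim - text_embeds_dim  # 2816 - 1280 = 1536
time_ids_dims = time_ids_channels // addition_time_embed_dim                 # 1536 // 256 = 6
assert time_ids_dims == 6  # matches the (1, 6) time_ids dummy created in init_weights
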
@@ -168,6 +183,24 @@ def setup(self):
         if isinstance(num_attention_heads, int):
             num_attention_heads = (num_attention_heads,) * len(self.down_block_types)

+        # transformer layers per block
+        transformer_layers_per_block = self.transformer_layers_per_block
+        if isinstance(transformer_layers_per_block, int):
+            transformer_layers_per_block = [transformer_layers_per_block] * len(self.down_block_types)
+
+        # addition embed types
+        if self.addition_embed_type is None:
+            self.add_embedding = None
+        elif self.addition_embed_type == "text_time":
+            if self.addition_time_embed_dim is None:
+                raise ValueError(
+                    f"addition_embed_type {self.addition_embed_type} requires `addition_time_embed_dim` to not be None"
+                )
+            self.add_time_proj = FlaxTimesteps(self.addition_time_embed_dim, self.flip_sin_to_cos, self.freq_shift)
+            self.add_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype)
+        else:
+            raise ValueError(f"addition_embed_type: {self.addition_embed_type} must be None or `text_time`.")

         # down
         down_blocks = []
         output_channel = block_out_channels[0]

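A standalone sketch of the broadcast rule above: an int keeps the pre-SDXL behavior of the same transformer depth in every block, while a tuple gives each block its own depth. The (1, 2, 10) values are an SDXL-style assumption, not taken from this diff:

# Broadcast rule from setup(), in isolation (three down blocks assumed)
def broadcast_depths(transformer_layers_per_block, num_blocks=3):
    if isinstance(transformer_layers_per_block, int):
        return [transformer_layers_per_block] * num_blocks
    return list(transformer_layers_per_block)

print(broadcast_depths(1))           # [1, 1, 1]  -> SD 1.x/2.x behavior
print(broadcast_depths((1, 2, 10)))  # [1, 2, 10] -> per-block depth
print(list(reversed(broadcast_depths((1, 2, 10)))))  # [10, 2, 1], as used for the up path
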
@@ -182,6 +215,7 @@ def setup(self):
                 out_channels=output_channel,
                 dropout=self.dropout,
                 num_layers=self.layers_per_block,
+                transformer_layers_per_block=transformer_layers_per_block[i],
                 num_attention_heads=num_attention_heads[i],
                 add_downsample=not is_final_block,
                 use_linear_projection=self.use_linear_projection,

@@ -207,6 +241,7 @@ def setup(self):
             in_channels=block_out_channels[-1],
             dropout=self.dropout,
             num_attention_heads=num_attention_heads[-1],
+            transformer_layers_per_block=transformer_layers_per_block[-1],
             use_linear_projection=self.use_linear_projection,
             use_memory_efficient_attention=self.use_memory_efficient_attention,
             dtype=self.dtype,

@@ -218,6 +253,7 @@ def setup(self):
         reversed_num_attention_heads = list(reversed(num_attention_heads))
         only_cross_attention = list(reversed(only_cross_attention))
         output_channel = reversed_block_out_channels[0]
+        reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
         for i, up_block_type in enumerate(self.up_block_types):
             prev_output_channel = output_channel
             output_channel = reversed_block_out_channels[i]

@@ -231,6 +267,7 @@ def setup(self):
                 out_channels=output_channel,
                 prev_output_channel=prev_output_channel,
                 num_layers=self.layers_per_block + 1,
+                transformer_layers_per_block=reversed_transformer_layers_per_block[i],
                 num_attention_heads=reversed_num_attention_heads[i],
                 add_upsample=not is_final_block,
                 dropout=self.dropout,

@@ -269,6 +306,7 @@ def __call__(
         sample,
         timesteps,
         encoder_hidden_states,
+        added_cond_kwargs: Optional[Union[Dict, FrozenDict]] = None,
         down_block_additional_residuals=None,
         mid_block_additional_residual=None,
         return_dict: bool = True,

@@ -300,6 +338,31 @@ def __call__(
         t_emb = self.time_proj(timesteps)
         t_emb = self.time_embedding(t_emb)

+        # additional embeddings
+        aug_emb = None
+        if self.addition_embed_type == "text_time":
+            if added_cond_kwargs is None:
+                raise ValueError(
+                    f"Need to provide argument `added_cond_kwargs` for {self.__class__} when using `addition_embed_type={self.addition_embed_type}`"
+                )
+            text_embeds = added_cond_kwargs.get("text_embeds")
+            if text_embeds is None:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
+                )
+            time_ids = added_cond_kwargs.get("time_ids")
+            if time_ids is None:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
+                )
+            # compute time embeds
+            time_embeds = self.add_time_proj(jnp.ravel(time_ids))  # (1, 6) => (6,) => (6, 256)
+            time_embeds = jnp.reshape(time_embeds, (text_embeds.shape[0], -1))
+            add_embeds = jnp.concatenate([text_embeds, time_embeds], axis=-1)
+            aug_emb = self.add_embedding(add_embeds)
+
+        t_emb = t_emb + aug_emb if aug_emb is not None else t_emb

         # 2. pre-process
         sample = jnp.transpose(sample, (0, 2, 3, 1))
         sample = self.conv_in(sample)

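Putting the pieces together, a minimal sketch of calling the updated model, continuing from the hypothetical `unet` configured earlier; input shapes mirror the init_weights dummies, and the six time_ids stand in for SDXL-style size/crop micro-conditioning:

import jax
import jax.numpy as jnp

# Dummy inputs; batch, latent, and sequence sizes are assumptions
sample = jnp.zeros((1, 4, 128, 128), dtype=jnp.float32)              # (B, C, H, W) latents
timesteps = jnp.ones((1,), dtype=jnp.int32)
encoder_hidden_states = jnp.zeros((1, 77, 2048), dtype=jnp.float32)  # cross-attention context

added_cond_kwargs = {
    "text_embeds": jnp.zeros((1, 1280), dtype=jnp.float32),  # pooled text embeddings
    "time_ids": jnp.zeros((1, 6), dtype=jnp.float32),        # micro-conditioning ids
}

params = unet.init_weights(jax.random.PRNGKey(0))
out = unet.apply({"params": params}, sample, timesteps, encoder_hidden_states, added_cond_kwargs)
print(out.sample.shape)  # expected (1, 4, 128, 128)
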
Reviewer comment: That's a refactor I should have done a while ago