SDXL flax #4254
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
SDXL flax #4254
Commits (45):
- cbdd6d6 support transformer_layers_per block in flax UNet (mar-muel)
- 4c78659 add support for text_time additional embeddings to Flax UNet (mar-muel)
- 3aa3164 rename attention layers for VAE (mar-muel)
- bc267e4 add shape asserts when renaming attention layers (mar-muel)
- ea0e675 transpose VAE attention layers (mar-muel)
- 1cc6c37 add pipeline flax SDXL code [WIP] (mar-muel)
- 9771feb continue add pipeline flax SDXL code [WIP] (mar-muel)
- 1137263 cleanup (mar-muel)
- f02d795 Working on JIT support (pcuenca)
- e6eaee0 Merge remote-tracking branch 'origin/main' into sdxl-flax (pcuenca)
- ff46fa2 Fixing embeddings (untested) (pcuenca)
- dfc3c81 Remove spurious line (pcuenca)
- 484e516 Shard guidance_scale when jitting. (pcuenca)
- 70e1058 Decode images (pcuenca)
- 2620ff3 Merge remote-tracking branch 'origin/main' into sdxl-flax (pcuenca)
- 8e63802 Fix sharding (pcuenca)
- 4fd885a style (pcuenca)
- 27d26ca Refiner UNet can be loaded. (pcuenca)
- 61d93f4 Refiner / img2img pipeline (pcuenca)
- 636d69f Allow latent outputs from base and latent inputs in refiner (pcuenca)
- 3a00bfe Adapt to FlaxCLIPTextModelOutput (pcuenca)
- a6442f7 Merge branch 'sdxl-flax' of github.com:huggingface/diffusers into sdx… (pcuenca)
- cba4327 Update Flax XL pipeline to FlaxCLIPTextModelOutput (pcuenca)
- a34ca77 Merge remote-tracking branch 'origin/main' into sdxl-flax (pcuenca)
- 4411bef make fix-copies (pcuenca)
- 5dc45d8 make style (pcuenca)
- dd64ffa add euler scheduler (patil-suraj)
- 3a3850f Merge remote-tracking branch 'origin' into sdxl-flax (pcuenca)
- 66fd712 Fix import (pcuenca)
- 4ecbc3f Fix copies, comment unused code. (pcuenca)
- 0a5be23 Fix SDXL Flax imports (pcuenca)
- 2bcf68d Fix euler discrete begin (patrickvonplaten)
- a9f3404 improve init import (patrickvonplaten)
- cba9f04 finish (patrickvonplaten)
- 2c8e77f put discrete euler in init (patrickvonplaten)
- 461be99 fix flax euler (patrickvonplaten)
- a975fc0 Merge branch 'sdxl-flax' of https://github.com/huggingface/diffusers … (patrickvonplaten)
- d02a50b Fix more (patrickvonplaten)
- 53d595a fix (patrickvonplaten)
- 3e1ffba make style (patrickvonplaten)
- e01da93 correct init (patrickvonplaten)
- 4346ed3 correct init (patrickvonplaten)
- da22045 Temporarily remove FlaxStableDiffusionXLImg2ImgPipeline (pcuenca)
- c503a36 correct pipelines (patrickvonplaten)
- b52c56d finish (patrickvonplaten)
Files changed (excerpt) — `FlaxUNet2DConditionModel` (Flax UNet):

```diff
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Optional, Tuple, Union
+from typing import Optional, Tuple, Union, Dict

 import flax
 import flax.linen as nn
```
```diff
@@ -116,6 +116,11 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
     flip_sin_to_cos: bool = True
     freq_shift: int = 0
     use_memory_efficient_attention: bool = False
+    transformer_layers_per_block: Union[int, Tuple[int]] = 1
+    addition_embed_type: Optional[str] = None
+    addition_time_embed_dim: Optional[int] = None
+    addition_embed_type_num_heads: int = 64
     projection_class_embeddings_input_dim: Optional[int] = None

     def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
         # init input tensors
```

Contributor review comment (on `addition_embed_type_num_heads`): "Is this num heads or …"
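For orientation, this is roughly how the new fields are filled for an SDXL-style checkpoint. The concrete numbers below are illustrative assumptions (the real values come from the model's config), not something this diff pins down:

```python
# Hypothetical SDXL-style values for the new config fields (assumptions for illustration):
sdxl_unet_overrides = {
    "transformer_layers_per_block": (1, 2, 10),  # per-block transformer depth instead of a single int
    "addition_embed_type": "text_time",          # enables the text_embeds / time_ids conditioning path
    "addition_time_embed_dim": 256,              # sinusoidal embedding size for each time_ids entry
    "addition_embed_type_num_heads": 64,
}
```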
```diff
@@ -127,7 +132,13 @@ def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
         params_rng, dropout_rng = jax.random.split(rng)
         rngs = {"params": params_rng, "dropout": dropout_rng}

-        return self.init(rngs, sample, timesteps, encoder_hidden_states)["params"]
+        added_cond_kwargs = None
+        if self.addition_embed_type == 'text_time':
+            added_cond_kwargs = {
+                'text_embeds': jnp.zeros((1, 1280), dtype=jnp.float32),  # TODO: This should be set based on config
+                'time_ids': jnp.zeros((1, 6), dtype=jnp.float32)
+            }
+        return self.init(rngs, sample, timesteps, encoder_hidden_states, added_cond_kwargs)["params"]

     def setup(self):
         block_out_channels = self.block_out_channels
```
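A minimal sketch of initializing parameters with the new path enabled; it assumes default values for all other config fields and is not taken from this PR:

```python
import jax

from diffusers import FlaxUNet2DConditionModel

# Sketch only: turn on the `text_time` additions on top of the default config.
unet = FlaxUNet2DConditionModel(
    addition_embed_type="text_time",
    addition_time_embed_dim=256,
)
# init_weights builds the dummy text_embeds / time_ids internally (see the hunk above).
params = unet.init_weights(jax.random.PRNGKey(0))
```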
```diff
@@ -168,6 +179,22 @@ def setup(self):
         if isinstance(num_attention_heads, int):
             num_attention_heads = (num_attention_heads,) * len(self.down_block_types)

+        # transformer layers per block
+        transformer_layers_per_block = self.transformer_layers_per_block
+        if isinstance(transformer_layers_per_block, int):
+            transformer_layers_per_block = [transformer_layers_per_block] * len(self.down_block_types)
+
+        # addition embed types
+        if self.addition_embed_type is None:
+            self.add_embedding = None
+        elif self.addition_embed_type == 'text_time':
+            if self.addition_time_embed_dim is None:
+                raise ValueError(f'addition_embed_type {self.addition_embed_type} requires `addition_time_embed_dim` to not be None')
+            self.add_time_proj = FlaxTimesteps(self.addition_time_embed_dim, self.flip_sin_to_cos, self.freq_shift)
+            self.add_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype)
+        else:
+            raise ValueError(f"addition_embed_type: {self.addition_embed_type} must be None or `text_time`.")
+
         # down
         down_blocks = []
         output_channel = block_out_channels[0]
```
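A standalone illustration of the per-block handling: a heterogeneous setting (as SDXL uses) is kept as-is, a single int fans out to one entry per down block, and the reversed list feeds the up blocks later in `setup`. Block names and values are examples only:

```python
down_block_types = ("DownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D")  # example only
transformer_layers_per_block = (1, 2, 10)  # or a single int such as 2

# Same normalization as in the hunk above.
if isinstance(transformer_layers_per_block, int):
    transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)

reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
print(reversed_transformer_layers_per_block)  # [10, 2, 1]
```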
```diff
@@ -182,6 +209,7 @@ def setup(self):
                 out_channels=output_channel,
                 dropout=self.dropout,
                 num_layers=self.layers_per_block,
+                transformer_layers_per_block=transformer_layers_per_block[i],
                 num_attention_heads=num_attention_heads[i],
                 add_downsample=not is_final_block,
                 use_linear_projection=self.use_linear_projection,
```
```diff
@@ -207,6 +235,7 @@ def setup(self):
             in_channels=block_out_channels[-1],
             dropout=self.dropout,
             num_attention_heads=num_attention_heads[-1],
+            transformer_layers_per_block=transformer_layers_per_block[-1],
             use_linear_projection=self.use_linear_projection,
             use_memory_efficient_attention=self.use_memory_efficient_attention,
             dtype=self.dtype,
```
```diff
@@ -218,6 +247,7 @@ def setup(self):
         reversed_num_attention_heads = list(reversed(num_attention_heads))
         only_cross_attention = list(reversed(only_cross_attention))
         output_channel = reversed_block_out_channels[0]
+        reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
         for i, up_block_type in enumerate(self.up_block_types):
             prev_output_channel = output_channel
             output_channel = reversed_block_out_channels[i]
```
```diff
@@ -231,6 +261,7 @@ def setup(self):
                 out_channels=output_channel,
                 prev_output_channel=prev_output_channel,
                 num_layers=self.layers_per_block + 1,
+                transformer_layers_per_block=reversed_transformer_layers_per_block[i],
                 num_attention_heads=reversed_num_attention_heads[i],
                 add_upsample=not is_final_block,
                 dropout=self.dropout,
```
```diff
@@ -269,6 +300,7 @@ def __call__(
         sample,
         timesteps,
         encoder_hidden_states,
+        added_cond_kwargs: Optional[Union[Dict, FrozenDict]] = None,
         down_block_additional_residuals=None,
         mid_block_additional_residual=None,
         return_dict: bool = True,
```
```diff
@@ -300,6 +332,29 @@ def __call__(
         t_emb = self.time_proj(timesteps)
         t_emb = self.time_embedding(t_emb)

+        # additional embeddings
+        aug_emb = None
+        if self.addition_embed_type == 'text_time':
+            if added_cond_kwargs is None:
+                raise ValueError(f'Need to provide argument `added_cond_kwargs` for {self.__class__} when using `addition_embed_type={self.addition_embed_type}`')
+            text_embeds = added_cond_kwargs.get("text_embeds")
+            if text_embeds is None:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
+                )
+            time_ids = added_cond_kwargs.get("time_ids")
+            if time_ids is None:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
+                )
+            # compute time embeds
+            time_embeds = self.add_time_proj(jnp.ravel(time_ids))  # (1, 6) => (6,) => (6, 256)
+            time_embeds = jnp.reshape(time_embeds, (text_embeds.shape[0], -1))
+            add_embeds = jnp.concatenate([text_embeds, time_embeds], axis=-1)
+            aug_emb = self.add_embedding(add_embeds)
+
+        t_emb = t_emb + aug_emb if aug_emb is not None else t_emb
+
         # 2. pre-process
         sample = jnp.transpose(sample, (0, 2, 3, 1))
         sample = self.conv_in(sample)
```
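The shape flow of the `text_time` branch, written out with plain `jax.numpy` so the `(1, 6) => (6,) => (6, 256)` comment is easy to follow. The 256/1280 sizes match the dummies used in `init_weights` above; the meaning of the six `time_ids` entries (original size, crop coordinates, target size) is SDXL's micro-conditioning, stated here for context rather than taken from this diff:

```python
import jax.numpy as jnp

batch = 1
text_embeds = jnp.zeros((batch, 1280))  # pooled text embedding
time_ids = jnp.zeros((batch, 6))        # e.g. (orig_h, orig_w, crop_top, crop_left, target_h, target_w)

# add_time_proj maps each scalar to a 256-dim sinusoidal embedding:
# (1, 6) -> ravel -> (6,) -> (6, 256). Zeros stand in for the real projection here.
sinusoids = jnp.zeros((time_ids.size, 256))

time_embeds = jnp.reshape(sinusoids, (text_embeds.shape[0], -1))   # (1, 1536)
add_embeds = jnp.concatenate([text_embeds, time_embeds], axis=-1)  # (1, 2816)
print(add_embeds.shape)  # (1, 2816) -> fed to add_embedding, then added to t_emb
```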
Review comment: TODO: verify dummy classes when transformers and/or flax are not installed.
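For context on this TODO: when `transformers` and/or `flax` are missing, diffusers exposes placeholder "dummy" classes that raise a helpful error on use instead of failing at import time. A rough sketch of that pattern (the real dummies are auto-generated; the class name here is just an example):

```python
from diffusers.utils import DummyObject, requires_backends


class FlaxStableDiffusionXLPipeline(metaclass=DummyObject):
    _backends = ["flax", "transformers"]

    def __init__(self, *args, **kwargs):
        # Raises an ImportError naming the missing backends instead of a cryptic failure.
        requires_backends(self, ["flax", "transformers"])
```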