[refactor embeddings]pixart-alpha #6212
Changes from all commits
@@ -22,7 +22,7 @@
 from ..models.embeddings import ImagePositionalEmbeddings
 from ..utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version
 from .attention import BasicTransformerBlock
-from .embeddings import CaptionProjection, PatchEmbed
+from .embeddings import PatchEmbed, PixArtAlphaTextProjection
 from .lora import LoRACompatibleConv, LoRACompatibleLinear
 from .modeling_utils import ModelMixin
 from .normalization import AdaLayerNormSingle
@@ -235,7 +235,7 @@ def __init__(

         self.caption_projection = None
         if caption_channels is not None:
-            self.caption_projection = CaptionProjection(in_features=caption_channels, hidden_size=inner_dim)
+            self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
Collaborator
Might actually be worth breaking Transformer2D up into a dedicated one for PixArt.

Member
For a future PR, yeah? I am happy to work on it once this is merged.

Collaborator (Author)
Definitely for a future PR. But I think we should refactor the transformers and UNets after we clean up all the lower-level classes, and make such decisions for all models/pipelines at once so it stays consistent.

Member
Yes, 100 percent!

         self.gradient_checkpointing = False
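For context, the module being renamed here is a small caption/text projection: it maps the text-encoder features (caption_channels) to the transformer's inner_dim before they reach the transformer blocks. Below is a minimal, self-contained sketch of that idea; it is an illustration under assumptions, not the exact diffusers implementation, and the class name, layer layout, and example sizes are hypothetical.

```python
import torch
import torch.nn as nn


class TextProjectionSketch(nn.Module):
    """Illustrative caption/text projection: Linear -> GELU(tanh) -> Linear.

    A rough stand-in for what a PixArtAlphaTextProjection-style module does;
    the details here are assumptions, not the library code.
    """

    def __init__(self, in_features: int, hidden_size: int):
        super().__init__()
        self.linear_1 = nn.Linear(in_features, hidden_size)
        self.act_1 = nn.GELU(approximate="tanh")
        self.linear_2 = nn.Linear(hidden_size, hidden_size)

    def forward(self, caption: torch.Tensor) -> torch.Tensor:
        # caption: (batch, seq_len, in_features) -> (batch, seq_len, hidden_size)
        hidden_states = self.linear_1(caption)
        hidden_states = self.act_1(hidden_states)
        return self.linear_2(hidden_states)


# Usage mirroring the guarded construction in the diff above (hypothetical sizes):
caption_channels, inner_dim = 4096, 1152
caption_projection = TextProjectionSketch(in_features=caption_channels, hidden_size=inner_dim)
projected = caption_projection(torch.randn(2, 120, caption_channels))
print(projected.shape)  # torch.Size([2, 120, 1152])
```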
@@ -853,6 +853,11 @@ def __call__(
             aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1)
             resolution = resolution.to(dtype=prompt_embeds.dtype, device=device)
             aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=device)
+
+            if do_classifier_free_guidance:
+                resolution = torch.cat([resolution, resolution], dim=0)
+                aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0)
Comment on lines +857 to +859
Member
Seems like a new addition?

Collaborator (Author)
Not really; it previously got duplicated later, inside the embedding (diffusers/src/diffusers/models/embeddings.py, line 758 in 6976cab).

Member
Very nice <3

Collaborator (Author)
@sayakpaul should we open a new PR to only update the tests, and I rebase after that? I'm not comfortable updating the tests directly from this PR since I updated the code.

Member
But that wasn't the case before. Wonder what changed.

             added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}

             # 7. Denoising loop
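The reason the duplication is needed at all: under classifier-free guidance the unconditional and conditional branches are run as one doubled batch, so every per-sample conditioning tensor, including resolution and aspect_ratio, must match that doubled batch dimension. A small sketch of the pattern follows; the names and shapes are illustrative, not the pipeline's actual code.

```python
import torch

batch_size, num_images_per_prompt = 1, 2
height, width = 1024, 1024
do_classifier_free_guidance = True

# Per-sample conditioning, one row per generated image.
resolution = torch.tensor([[float(height), float(width)]]).repeat(batch_size * num_images_per_prompt, 1)
aspect_ratio = torch.tensor([[float(height / width)]]).repeat(batch_size * num_images_per_prompt, 1)

# Latents for the denoising loop, one per generated image (illustrative shape).
latents = torch.randn(batch_size * num_images_per_prompt, 4, 128, 128)

if do_classifier_free_guidance:
    # The model input is the uncond/cond batches concatenated along dim 0 ...
    latent_model_input = torch.cat([latents, latents], dim=0)
    # ... so the extra conditioning must be doubled the same way to keep shapes aligned.
    resolution = torch.cat([resolution, resolution], dim=0)
    aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0)
else:
    latent_model_input = latents

assert latent_model_input.shape[0] == resolution.shape[0] == aspect_ratio.shape[0]
print(latent_model_input.shape[0], resolution.shape, aspect_ratio.shape)
```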
very nice refactor!