35 changes: 8 additions & 27 deletions src/diffusers/models/embeddings.py
@@ -729,7 +729,7 @@ def forward(
return objs


class CombinedTimestepSizeEmbeddings(nn.Module):
class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
"""
For PixArt-Alpha.

@@ -746,45 +746,27 @@ def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool

self.use_additional_conditions = use_additional_conditions
if use_additional_conditions:
self.use_additional_conditions = True
self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)

def apply_condition(self, size: torch.Tensor, batch_size: int, embedder: nn.Module):
if size.ndim == 1:
size = size[:, None]

if size.shape[0] != batch_size:
size = size.repeat(batch_size // size.shape[0], 1)
if size.shape[0] != batch_size:
raise ValueError(f"`batch_size` should be {size.shape[0]} but found {batch_size}.")

current_batch_size, dims = size.shape[0], size.shape[1]
size = size.reshape(-1)
size_freq = self.additional_condition_proj(size).to(size.dtype)

size_emb = embedder(size_freq)
size_emb = size_emb.reshape(current_batch_size, dims * self.outdim)
return size_emb

def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)

if self.use_additional_conditions:
resolution = self.apply_condition(resolution, batch_size=batch_size, embedder=self.resolution_embedder)
aspect_ratio = self.apply_condition(
aspect_ratio, batch_size=batch_size, embedder=self.aspect_ratio_embedder
)
conditioning = timesteps_emb + torch.cat([resolution, aspect_ratio], dim=1)
resolution_emb = self.additional_condition_proj(resolution.flatten()).to(hidden_dtype)
Contributor:
very nice refactor!

resolution_emb = self.resolution_embedder(resolution_emb).reshape(batch_size, -1)
aspect_ratio_emb = self.additional_condition_proj(aspect_ratio.flatten()).to(hidden_dtype)
aspect_ratio_emb = self.aspect_ratio_embedder(aspect_ratio_emb).reshape(batch_size, -1)
conditioning = timesteps_emb + torch.cat([resolution_emb, aspect_ratio_emb], dim=1)
else:
conditioning = timesteps_emb

return conditioning


class CaptionProjection(nn.Module):
class PixArtAlphaTextProjection(nn.Module):
"""
Projects caption embeddings. Also handles dropout for classifier-free guidance.

@@ -796,9 +778,8 @@ def __init__(self, in_features, hidden_size, num_tokens=120):
self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
self.act_1 = nn.GELU(approximate="tanh")
self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True)
self.register_buffer("y_embedding", nn.Parameter(torch.randn(num_tokens, in_features) / in_features**0.5))

def forward(self, caption, force_drop_ids=None):
def forward(self, caption):
hidden_states = self.linear_1(caption)
hidden_states = self.act_1(hidden_states)
hidden_states = self.linear_2(hidden_states)
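A minimal usage sketch of the two renamed modules after this refactor (not part of the PR; the 1152/4096/120 dimensions are assumptions matching typical PixArt-Alpha configs, and all tensors are dummy values):

import torch
from diffusers.models.embeddings import (
    PixArtAlphaCombinedTimestepSizeEmbeddings,
    PixArtAlphaTextProjection,
)

batch_size = 2
embedding_dim = 1152  # assumed: inner_dim of the PixArt-Alpha transformer

# Timestep + size conditioning; size_emb_dim = embedding_dim // 3 so that the
# concatenated resolution/aspect-ratio embeddings match the timestep embedding.
emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
    embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=True
)
timestep = torch.randint(0, 1000, (batch_size,))
resolution = torch.tensor([[1024.0, 1024.0]] * batch_size)  # already batch-sized
aspect_ratio = torch.tensor([[1.0]] * batch_size)
conditioning = emb(
    timestep, resolution, aspect_ratio, batch_size=batch_size, hidden_dtype=torch.float32
)
print(conditioning.shape)  # torch.Size([2, 1152])

# Caption projection: forward now takes only `caption`.
proj = PixArtAlphaTextProjection(in_features=4096, hidden_size=embedding_dim)
caption = torch.randn(batch_size, 120, 4096)  # assumed T5 sequence of 120 tokens
print(proj(caption).shape)  # torch.Size([2, 120, 1152])

Note that resolution and aspect_ratio are expected to arrive already batch-sized, since the repeat-to-batch logic that used to live in apply_condition is gone.
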
4 changes: 2 additions & 2 deletions src/diffusers/models/normalization.py
@@ -20,7 +20,7 @@
import torch.nn.functional as F

from .activations import get_activation
from .embeddings import CombinedTimestepLabelEmbeddings, CombinedTimestepSizeEmbeddings
from .embeddings import CombinedTimestepLabelEmbeddings, PixArtAlphaCombinedTimestepSizeEmbeddings


class AdaLayerNorm(nn.Module):
@@ -91,7 +91,7 @@ class AdaLayerNormSingle(nn.Module):
def __init__(self, embedding_dim: int, use_additional_conditions: bool = False):
super().__init__()

self.emb = CombinedTimestepSizeEmbeddings(
self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions
)

4 changes: 2 additions & 2 deletions src/diffusers/models/transformer_2d.py
@@ -22,7 +22,7 @@
from ..models.embeddings import ImagePositionalEmbeddings
from ..utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version
from .attention import BasicTransformerBlock
from .embeddings import CaptionProjection, PatchEmbed
from .embeddings import PatchEmbed, PixArtAlphaTextProjection
from .lora import LoRACompatibleConv, LoRACompatibleLinear
from .modeling_utils import ModelMixin
from .normalization import AdaLayerNormSingle
@@ -235,7 +235,7 @@ def __init__(

self.caption_projection = None
if caption_channels is not None:
self.caption_projection = CaptionProjection(in_features=caption_channels, hidden_size=inner_dim)
self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
Collaborator:
Might actually be worth breaking Transformer2D up into a dedicated one for PixArt.

Member:
For a future PR, yeah? I am happy to work on it once this is merged.

Collaborator Author (@yiyixuxu, Dec 19, 2023):
Definitely for a future PR.

But I think we should refactor transformers and UNet after we clean up all the lower-level classes, and make such decisions for all models/pipelines at once so it will be consistent.

Member:
Yes, 100 percent!


self.gradient_checkpointing = False

5 changes: 5 additions & 0 deletions src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
@@ -853,6 +853,11 @@ def __call__(
aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1)
resolution = resolution.to(dtype=prompt_embeds.dtype, device=device)
aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=device)

if do_classifier_free_guidance:
resolution = torch.cat([resolution, resolution], dim=0)
aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0)
Comment on lines +857 to +859
Member:
Seems like a new addition?

Collaborator Author (@yiyixuxu, Dec 18, 2023):
Not really - it gets duplicated later inside the embedding, by the removed apply_condition:

if size.shape[0] != batch_size:

Member:
Very nice <3
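A small illustration of the point above (dummy values, not from the PR): under classifier-free guidance the prompt/latent batch is doubled, and with apply_condition removed the pipeline now duplicates the size tensors explicitly instead of relying on the old size.repeat(...) fallback inside the embedding:

import torch

# assumed values for illustration
height, width = 1024, 1024
batch_size, num_images_per_prompt = 1, 1
do_classifier_free_guidance = True

resolution = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1)
aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1)

if do_classifier_free_guidance:
    # explicit duplication to match the doubled (uncond + cond) batch
    resolution = torch.cat([resolution, resolution], dim=0)
    aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0)

print(resolution.shape, aspect_ratio.shape)  # torch.Size([2, 2]) torch.Size([2, 1])
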

Collaborator Author:
@sayakpaul The fast tests fail because there are some randomly initialized weights in some components. I think we need to put torch.manual_seed(0) before making each component, e.g.

vae = AutoencoderKL()

Should we open a new PR to only update the tests, and I rebase after that? I'm not comfortable updating the tests directly from this PR since I updated the code.

Member:
But that wasn't the case before. Wonder what changed.
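A minimal sketch of the proposed test fix, assuming default component constructors; only the vae line comes from the comment above:

import torch
from diffusers import AutoencoderKL

# seed immediately before constructing each randomly initialized component
# so the fast-test outputs stay deterministic
torch.manual_seed(0)
vae = AutoencoderKL()
# ...repeat torch.manual_seed(0) before the transformer, text encoder, etc.
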



added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}

# 7. Denoising loop
Expand Down