From 93e397843c34e28b6a66ab1ab57ac0881e65b07c Mon Sep 17 00:00:00 2001
From: yiyixuxu
Date: Mon, 18 Dec 2023 10:00:54 +0000
Subject: [PATCH] pixart-alpha

---
 src/diffusers/models/embeddings.py             | 35 +++++--------------
 src/diffusers/models/normalization.py          |  4 +--
 src/diffusers/models/transformer_2d.py         |  4 +--
 .../pixart_alpha/pipeline_pixart_alpha.py      |  5 +++
 4 files changed, 17 insertions(+), 31 deletions(-)

diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index 73abc9869230..db68591bdb44 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -729,7 +729,7 @@ def forward(
         return objs
 
 
-class CombinedTimestepSizeEmbeddings(nn.Module):
+class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
     """
     For PixArt-Alpha.
 
@@ -746,45 +746,27 @@ def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool
 
         self.use_additional_conditions = use_additional_conditions
         if use_additional_conditions:
-            self.use_additional_conditions = True
             self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
             self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
             self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
 
-    def apply_condition(self, size: torch.Tensor, batch_size: int, embedder: nn.Module):
-        if size.ndim == 1:
-            size = size[:, None]
-
-        if size.shape[0] != batch_size:
-            size = size.repeat(batch_size // size.shape[0], 1)
-            if size.shape[0] != batch_size:
-                raise ValueError(f"`batch_size` should be {size.shape[0]} but found {batch_size}.")
-
-        current_batch_size, dims = size.shape[0], size.shape[1]
-        size = size.reshape(-1)
-        size_freq = self.additional_condition_proj(size).to(size.dtype)
-
-        size_emb = embedder(size_freq)
-        size_emb = size_emb.reshape(current_batch_size, dims * self.outdim)
-        return size_emb
-
     def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
         timesteps_proj = self.time_proj(timestep)
         timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (N, D)
 
         if self.use_additional_conditions:
-            resolution = self.apply_condition(resolution, batch_size=batch_size, embedder=self.resolution_embedder)
-            aspect_ratio = self.apply_condition(
-                aspect_ratio, batch_size=batch_size, embedder=self.aspect_ratio_embedder
-            )
-            conditioning = timesteps_emb + torch.cat([resolution, aspect_ratio], dim=1)
+            resolution_emb = self.additional_condition_proj(resolution.flatten()).to(hidden_dtype)
+            resolution_emb = self.resolution_embedder(resolution_emb).reshape(batch_size, -1)
+            aspect_ratio_emb = self.additional_condition_proj(aspect_ratio.flatten()).to(hidden_dtype)
+            aspect_ratio_emb = self.aspect_ratio_embedder(aspect_ratio_emb).reshape(batch_size, -1)
+            conditioning = timesteps_emb + torch.cat([resolution_emb, aspect_ratio_emb], dim=1)
         else:
             conditioning = timesteps_emb
 
         return conditioning
 
 
-class CaptionProjection(nn.Module):
+class PixArtAlphaTextProjection(nn.Module):
     """
     Projects caption embeddings. Also handles dropout for classifier-free guidance.
 
@@ -796,9 +778,8 @@ def __init__(self, in_features, hidden_size, num_tokens=120):
         self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
         self.act_1 = nn.GELU(approximate="tanh")
         self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True)
-        self.register_buffer("y_embedding", nn.Parameter(torch.randn(num_tokens, in_features) / in_features**0.5))
 
-    def forward(self, caption, force_drop_ids=None):
+    def forward(self, caption):
         hidden_states = self.linear_1(caption)
         hidden_states = self.act_1(hidden_states)
         hidden_states = self.linear_2(hidden_states)
diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py
index 11d2a344744e..25af4d853b86 100644
--- a/src/diffusers/models/normalization.py
+++ b/src/diffusers/models/normalization.py
@@ -20,7 +20,7 @@
 import torch.nn.functional as F
 
 from .activations import get_activation
-from .embeddings import CombinedTimestepLabelEmbeddings, CombinedTimestepSizeEmbeddings
+from .embeddings import CombinedTimestepLabelEmbeddings, PixArtAlphaCombinedTimestepSizeEmbeddings
 
 
 class AdaLayerNorm(nn.Module):
@@ -91,7 +91,7 @@ class AdaLayerNormSingle(nn.Module):
     def __init__(self, embedding_dim: int, use_additional_conditions: bool = False):
         super().__init__()
 
-        self.emb = CombinedTimestepSizeEmbeddings(
+        self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
             embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions
         )
 
diff --git a/src/diffusers/models/transformer_2d.py b/src/diffusers/models/transformer_2d.py
index 3aecc43f0f5b..128395cc161a 100644
--- a/src/diffusers/models/transformer_2d.py
+++ b/src/diffusers/models/transformer_2d.py
@@ -22,7 +22,7 @@
 from ..models.embeddings import ImagePositionalEmbeddings
 from ..utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version
 from .attention import BasicTransformerBlock
-from .embeddings import CaptionProjection, PatchEmbed
+from .embeddings import PatchEmbed, PixArtAlphaTextProjection
 from .lora import LoRACompatibleConv, LoRACompatibleLinear
 from .modeling_utils import ModelMixin
 from .normalization import AdaLayerNormSingle
@@ -235,7 +235,7 @@ def __init__(
 
         self.caption_projection = None
         if caption_channels is not None:
-            self.caption_projection = CaptionProjection(in_features=caption_channels, hidden_size=inner_dim)
+            self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
 
         self.gradient_checkpointing = False
 
diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
index 090b66915dd0..82a170400068 100644
--- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
+++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
@@ -853,6 +853,11 @@ def __call__(
             aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1)
             resolution = resolution.to(dtype=prompt_embeds.dtype, device=device)
             aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=device)
+
+            if do_classifier_free_guidance:
+                resolution = torch.cat([resolution, resolution], dim=0)
+                aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0)
+
             added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}
 
         # 7. Denoising loop
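
Note (not part of the patch): below is a minimal usage sketch of the renamed embedding class, assuming this patch is applied to diffusers at the commit above. The embedding_dim, batch size, and condition values are illustrative assumptions; AdaLayerNormSingle (see the normalization.py hunk) passes size_emb_dim = embedding_dim // 3, which is what makes the resolution/aspect-ratio embeddings line up with the timestep embedding.

    import torch
    from diffusers.models.embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings

    batch_size = 2
    embedding_dim = 1152  # assumed value, matching the PixArt-Alpha transformer's inner_dim

    emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
        embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=True
    )

    timestep = torch.randint(0, 1000, (batch_size,))
    resolution = torch.tensor([[1024.0, 1024.0]] * batch_size)  # (height, width) per sample
    aspect_ratio = torch.tensor([[1.0]] * batch_size)

    # resolution contributes 2 * size_emb_dim features and aspect_ratio contributes size_emb_dim,
    # so their concatenation matches timesteps_emb with shape (batch_size, embedding_dim).
    conditioning = emb(
        timestep,
        resolution=resolution,
        aspect_ratio=aspect_ratio,
        batch_size=batch_size,
        hidden_dtype=torch.float32,
    )
    print(conditioning.shape)  # torch.Size([2, 1152])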