diff --git a/ppdiffusers/ppdiffusers/models/embeddings.py b/ppdiffusers/ppdiffusers/models/embeddings.py
index 0446fcbd9..0369e3bb0 100644
--- a/ppdiffusers/ppdiffusers/models/embeddings.py
+++ b/ppdiffusers/ppdiffusers/models/embeddings.py
@@ -19,7 +19,7 @@ from paddle import nn
 
 
 from ..utils import USE_PEFT_BACKEND
-from .activations import get_activation, FP32SiLU
+from .activations import FP32SiLU, get_activation
 from .lora import LoRACompatibleLinear
 
 
@@ -52,12 +52,11 @@ def get_timestep_embedding(
     # scale embeddings
     emb = scale * emb
 
-    # concat sine and cosine embeddings
-    emb = paddle.concat([paddle.sin(emb), paddle.cos(emb)], axis=-1)
-
     # flip sine and cosine embeddings
     if flip_sin_to_cos:
-        emb = paddle.concat([emb[:, half_dim:], emb[:, :half_dim]], axis=-1)
+        emb = paddle.concat([paddle.cos(emb), paddle.sin(emb)], axis=-1)
+    else:
+        emb = paddle.concat([paddle.sin(emb), paddle.cos(emb)], axis=-1)
 
     # zero pad
     if embedding_dim % 2 == 1:
@@ -136,7 +135,7 @@ def __init__(
         interpolation_scale=1,
         add_pos_embed=True,
         data_format="NCHW",
-        pos_embed_max_size=None, # For SD3 cropping
+        pos_embed_max_size=None,  # For SD3 cropping
     ):
         super().__init__()
 
@@ -147,7 +146,12 @@ def __init__(
         self.data_format = data_format
 
         self.proj = nn.Conv2D(
-            in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias_attr=bias, data_format=data_format,
+            in_channels,
+            embed_dim,
+            kernel_size=(patch_size, patch_size),
+            stride=patch_size,
+            bias_attr=bias,
+            data_format=data_format,
         )
         if layer_norm:
             norm_elementwise_affine_kwargs = dict(weight_attr=False, bias_attr=False)
@@ -178,6 +182,7 @@ def __init__(
             self.register_buffer(
                 "pos_embed", paddle.to_tensor(pos_embed).cast("float32").unsqueeze(0), persistable=persistent
             )
+
     def cropped_pos_embed(self, height, width):
         """Crops positional embeddings for SD3 compatibility."""
         if self.pos_embed_max_size is None:
@@ -215,7 +220,7 @@ def forward(self, latent):
         if self.data_format == "NCHW":
             latent = latent.flatten(2).transpose([0, 2, 1])  # BCHW -> BNC
         else:
-            latent = latent.flatten(1, 2) # BHWC -> BNC
+            latent = latent.flatten(1, 2)  # BHWC -> BNC
 
         if self.layer_norm:
             latent = self.norm(latent)
@@ -521,7 +526,6 @@ def forward(self, image_embeds: paddle.Tensor):
         return image_embeds
 
 
-
 class CombinedTimestepTextProjEmbeddings(nn.Layer):
     def __init__(self, embedding_dim, pooled_projection_dim):
         super().__init__()
@@ -532,7 +536,7 @@ def __init__(self, embedding_dim, pooled_projection_dim):
 
     def forward(self, timestep, pooled_projection):
         timesteps_proj = self.time_proj(timestep)
-        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))  # (N, D)
+        timesteps_emb = self.timestep_embedder(timesteps_proj.cast(dtype=pooled_projection.dtype))  # (N, D)
 
         pooled_projections = self.text_embedder(pooled_projection)
 
@@ -540,6 +544,7 @@ def __init__(self, embedding_dim, pooled_projection_dim):
 
         return conditioning
 
+
 class CombinedTimestepLabelEmbeddings(nn.Layer):
     def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
         super().__init__()
@@ -906,4 +911,4 @@ def forward(self, caption):
         hidden_states = self.linear_1(caption)
         hidden_states = self.act_1(hidden_states)
         hidden_states = self.linear_2(hidden_states)
-        return hidden_states
\ No newline at end of file
+        return hidden_states
diff --git a/ppdiffusers/ppdiffusers/patches/paddle_patch.py b/ppdiffusers/ppdiffusers/patches/paddle_patch.py
index 6845ad632..b55cbf138 100644
--- a/ppdiffusers/ppdiffusers/patches/paddle_patch.py
+++ b/ppdiffusers/ppdiffusers/patches/paddle_patch.py
@@ -429,7 +429,7 @@ def scaled_dot_product_attention_(
         # (2) FLAG_USE_CUTLASS_V2 in yes, y, true, t, 1, use cutlass v2
         use_cutlass_v2 = attn_mask is not None or str2bool(os.getenv("FLAG_USE_CUTLASS_V2", "no"))
         if not use_cutlass_v2:
-            with requires_grad_and_without_random(query, key, value):
+            with requires_grad_and_without_random(query, key, value, stop_gradient=False):
                 output = memory_efficient_attention(
                     query,
                     key,
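For reviewers, a minimal standalone sketch (an illustration, not part of the patch) of why the `get_timestep_embedding` change is behaviour-preserving: emitting `[cos, sin]` directly when `flip_sin_to_cos` is set produces the same tensor as the old concat-then-swap-halves code. The frequency computation below is simplified for the example (it assumes `downscale_freq_shift=0`, `scale=1`, and an even `embedding_dim`).

```python
import math

import paddle

timesteps = paddle.to_tensor([0.0, 10.0, 500.0])
half_dim = 4  # embedding_dim // 2

# Simplified sinusoidal frequencies (illustrative only).
exponent = -math.log(10000.0) * paddle.arange(half_dim, dtype="float32") / half_dim
emb = timesteps[:, None].cast("float32") * paddle.exp(exponent)[None, :]

# Old path: concat [sin, cos], then swap the two halves when flip_sin_to_cos is True.
old = paddle.concat([paddle.sin(emb), paddle.cos(emb)], axis=-1)
old = paddle.concat([old[:, half_dim:], old[:, :half_dim]], axis=-1)

# New path: build [cos, sin] directly.
new = paddle.concat([paddle.cos(emb), paddle.sin(emb)], axis=-1)

assert paddle.allclose(old, new).item()
```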