Reduce redundant copies and fix a bug #699

Merged. 1 commit merged on Sep 6, 2024.
27 changes: 16 additions & 11 deletions ppdiffusers/ppdiffusers/models/embeddings.py
@@ -19,7 +19,7 @@
 from paddle import nn
 
 from ..utils import USE_PEFT_BACKEND
-from .activations import get_activation, FP32SiLU
+from .activations import FP32SiLU, get_activation
 from .lora import LoRACompatibleLinear
 
 
@@ -52,12 +52,11 @@ def get_timestep_embedding(
     # scale embeddings
     emb = scale * emb
 
-    # concat sine and cosine embeddings
-    emb = paddle.concat([paddle.sin(emb), paddle.cos(emb)], axis=-1)
-
     # flip sine and cosine embeddings
     if flip_sin_to_cos:
-        emb = paddle.concat([emb[:, half_dim:], emb[:, :half_dim]], axis=-1)
+        emb = paddle.concat([paddle.cos(emb), paddle.sin(emb)], axis=-1)
+    else:
+        emb = paddle.concat([paddle.sin(emb), paddle.cos(emb)], axis=-1)
 
     # zero pad
     if embedding_dim % 2 == 1:
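Note (not part of the diff): the hunk above is the "reduce redundant copies" part of this PR. The old code built [sin | cos] and then flipped it by slicing, which costs a second concat plus a slice copy; the new code builds [cos | sin] directly when flip_sin_to_cos is set. A minimal sketch with a stand-in emb tensor (shape and half_dim are made up) showing the two paths agree:

import paddle

emb = paddle.rand([4, 8])   # stand-in for exponent * timesteps
half_dim = emb.shape[-1]

# Old path: build [sin | cos], then flip by slicing (two concats plus a slice copy).
old = paddle.concat([paddle.sin(emb), paddle.cos(emb)], axis=-1)
old = paddle.concat([old[:, half_dim:], old[:, :half_dim]], axis=-1)

# New path: build [cos | sin] directly (a single concat).
new = paddle.concat([paddle.cos(emb), paddle.sin(emb)], axis=-1)

assert paddle.allclose(old, new).item()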
@@ -136,7 +135,7 @@ def __init__(
         interpolation_scale=1,
         add_pos_embed=True,
         data_format="NCHW",
-        pos_embed_max_size=None, # For SD3 cropping
+        pos_embed_max_size=None,  # For SD3 cropping
     ):
         super().__init__()
 
@@ -147,7 +146,12 @@ def __init__(
         self.data_format = data_format
 
         self.proj = nn.Conv2D(
-            in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias_attr=bias, data_format=data_format,
+            in_channels,
+            embed_dim,
+            kernel_size=(patch_size, patch_size),
+            stride=patch_size,
+            bias_attr=bias,
+            data_format=data_format,
         )
         if layer_norm:
             norm_elementwise_affine_kwargs = dict(weight_attr=False, bias_attr=False)
@@ -178,6 +182,7 @@ def __init__(
             self.register_buffer(
                 "pos_embed", paddle.to_tensor(pos_embed).cast("float32").unsqueeze(0), persistable=persistent
             )
+
     def cropped_pos_embed(self, height, width):
         """Crops positional embeddings for SD3 compatibility."""
         if self.pos_embed_max_size is None:
@@ -215,7 +220,7 @@ def forward(self, latent):
         if self.data_format == "NCHW":
             latent = latent.flatten(2).transpose([0, 2, 1])  # BCHW -> BNC
         else:
-            latent = latent.flatten(1, 2) # BHWC -> BNC
+            latent = latent.flatten(1, 2)  # BHWC -> BNC
         if self.layer_norm:
             latent = self.norm(latent)
 
@@ -521,7 +526,6 @@ def forward(self, image_embeds: paddle.Tensor):
         return image_embeds
 
 
-
 class CombinedTimestepTextProjEmbeddings(nn.Layer):
     def __init__(self, embedding_dim, pooled_projection_dim):
         super().__init__()
@@ -532,14 +536,15 @@ def __init__(self, embedding_dim, pooled_projection_dim):
 
     def forward(self, timestep, pooled_projection):
         timesteps_proj = self.time_proj(timestep)
-        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))  # (N, D)
+        timesteps_emb = self.timestep_embedder(timesteps_proj.cast(dtype=pooled_projection.dtype))  # (N, D)
 
         pooled_projections = self.text_embedder(pooled_projection)
 
         conditioning = timesteps_emb + pooled_projections
 
         return conditioning
 
+
 class CombinedTimestepLabelEmbeddings(nn.Layer):
     def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
         super().__init__()
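Note (not part of the diff): the one functional change in the hunk above is the bug fix, timesteps_proj.to(...) becoming timesteps_proj.cast(...). Tensor.cast is Paddle's dtype-conversion method, as the new line itself shows; the sketch below only illustrates the dtype-matching pattern with stand-in tensors (float64 is used so the example runs anywhere, real pipelines would typically use float16 or bfloat16):

import paddle

timesteps_proj = paddle.rand([2, 256])                      # float32 stand-in
pooled_projection = paddle.rand([2, 2048]).cast("float64")  # stand-in for a different dtype

# Align the timestep projection with the pooled projection's dtype before the MLP,
# mirroring the .cast(dtype=...) call in the hunk above.
aligned = timesteps_proj.cast(dtype=pooled_projection.dtype)
assert aligned.dtype == pooled_projection.dtype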
@@ -906,4 +911,4 @@ def forward(self, caption):
         hidden_states = self.linear_1(caption)
         hidden_states = self.act_1(hidden_states)
         hidden_states = self.linear_2(hidden_states)
-        return hidden_states
+        return hidden_states
2 changes: 1 addition & 1 deletion ppdiffusers/ppdiffusers/patches/paddle_patch.py
@@ -429,7 +429,7 @@ def scaled_dot_product_attention_(
         # (2) FLAG_USE_CUTLASS_V2 in yes, y, true, t, 1, use cutlass v2
         use_cutlass_v2 = attn_mask is not None or str2bool(os.getenv("FLAG_USE_CUTLASS_V2", "no"))
         if not use_cutlass_v2:
-            with requires_grad_and_without_random(query, key, value):
+            with requires_grad_and_without_random(query, key, value, stop_gradient=False):
                 output = memory_efficient_attention(
                     query,
                     key,
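Note (not part of the diff): the fix here passes stop_gradient=False through requires_grad_and_without_random. That helper's body is not shown in this diff, so exactly how it applies the flag is an assumption; the flag itself is Paddle's standard per-tensor gradient switch. A minimal sketch of its semantics with throwaway tensors:

import paddle

x = paddle.rand([4, 8])
x.stop_gradient = False   # opt in to gradient tracking (data tensors default to True)
w = paddle.rand([8, 2])
w.stop_gradient = True    # treated as a constant: no gradient is recorded for it

loss = paddle.matmul(x, w).sum()
loss.backward()

print(x.grad.shape)  # [4, 8], the gradient flowed back to x
print(w.grad)        # None, blocked by stop_gradient=True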