diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py
index 5d4c7429e43f..11a32a92aee8 100644
--- a/src/diffusers/loaders/unet.py
+++ b/src/diffusers/loaders/unet.py
@@ -24,7 +24,7 @@
 from huggingface_hub.utils import validate_hf_hub_args
 from torch import nn

-from ..models.embeddings import ImageProjection, MLPProjection, Resampler
+from ..models.embeddings import ImageProjection, IPAdapterFullImageProjection, IPAdapterPlusImageProjection
 from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
 from ..utils import (
     USE_PEFT_BACKEND,
@@ -712,7 +712,7 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict):
             clip_embeddings_dim = state_dict["proj.0.weight"].shape[0]
             cross_attention_dim = state_dict["proj.3.weight"].shape[0]

-            image_projection = MLPProjection(
+            image_projection = IPAdapterFullImageProjection(
                 cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim
             )
@@ -730,7 +730,7 @@
             hidden_dims = state_dict["latents"].shape[2]
             heads = state_dict["layers.0.0.to_q.weight"].shape[0] // 64

-            image_projection = Resampler(
+            image_projection = IPAdapterPlusImageProjection(
                 embed_dims=embed_dims,
                 output_dims=output_dims,
                 hidden_dims=hidden_dims,
@@ -780,7 +780,7 @@ def _load_ip_adapter_weights(self, state_dict):
             num_image_text_embeds = state_dict["image_proj"]["latents"].shape[1]

         # Set encoder_hid_proj after loading ip_adapter weights,
-        # because `Resampler` also has `attn_processors`.
+        # because `IPAdapterPlusImageProjection` also has `attn_processors`.
         self.encoder_hid_proj = None

         # set ip-adapter cross-attention processors & load state_dict
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index 7e98f77baf26..293b751cb67d 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -462,7 +462,7 @@ def forward(self, image_embeds: torch.FloatTensor):
         return image_embeds


-class MLPProjection(nn.Module):
+class IPAdapterFullImageProjection(nn.Module):
     def __init__(self, image_embed_dim=1024, cross_attention_dim=1024):
         super().__init__()
         from .attention import FeedForward
@@ -621,29 +621,34 @@ def shape(x):
         return a[:, 0, :]  # cls_token


-class FourierEmbedder(nn.Module):
-    def __init__(self, num_freqs=64, temperature=100):
-        super().__init__()
+def get_fourier_embeds_from_boundingbox(embed_dim, box):
+    """
+    Args:
+        embed_dim: int
+        box: a 3-D tensor [B x N x 4] representing the bounding boxes for GLIGEN pipeline
+    Returns:
+        [B x N x embed_dim] tensor of positional embeddings
+    """
+
+    batch_size, num_boxes = box.shape[:2]

-        self.num_freqs = num_freqs
-        self.temperature = temperature
+    emb = 100 ** (torch.arange(embed_dim) / embed_dim)
+    emb = emb[None, None, None].to(device=box.device, dtype=box.dtype)
+    emb = emb * box.unsqueeze(-1)

-        freq_bands = temperature ** (torch.arange(num_freqs) / num_freqs)
-        freq_bands = freq_bands[None, None, None]
-        self.register_buffer("freq_bands", freq_bands, persistent=False)
+    emb = torch.stack((emb.sin(), emb.cos()), dim=-1)
+    emb = emb.permute(0, 1, 3, 4, 2).reshape(batch_size, num_boxes, embed_dim * 2 * 4)

-    def __call__(self, x):
-        x = self.freq_bands * x.unsqueeze(-1)
-        return torch.stack((x.sin(), x.cos()), dim=-1).permute(0, 1, 3, 4, 2).reshape(*x.shape[:2], -1)
+    return emb


-class PositionNet(nn.Module):
+class GLIGENTextBoundingboxProjection(nn.Module):
     def __init__(self, positive_len, out_dim, feature_type="text-only", fourier_freqs=8):
         super().__init__()
         self.positive_len = positive_len
         self.out_dim = out_dim
-        self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs)
+        self.fourier_embedder_dim = fourier_freqs
         self.position_dim = fourier_freqs * 2 * 4  # 2: sin/cos, 4: xyxy

         if isinstance(out_dim, tuple):
@@ -692,7 +697,7 @@ def forward(
         masks = masks.unsqueeze(-1)

         # embedding position (it may includes padding as placeholder)
-        xyxy_embedding = self.fourier_embedder(boxes)  # B*N*4 -> B*N*C
+        xyxy_embedding = get_fourier_embeds_from_boundingbox(self.fourier_embedder_dim, boxes)  # B*N*4 -> B*N*C

         # learnable null embedding
         xyxy_null = self.null_position_feature.view(1, 1, -1)
@@ -787,7 +792,7 @@ def forward(self, caption):
         return hidden_states


-class Resampler(nn.Module):
+class IPAdapterPlusImageProjection(nn.Module):
     """Resampler of IP-Adapter Plus.

     Args:
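A quick shape check of the new helper. This is a sketch, not part of the patch; it assumes a build that includes this change so the function is importable from diffusers.models.embeddings (otherwise, copy the function body above):

import torch
from diffusers.models.embeddings import get_fourier_embeds_from_boundingbox

# one batch, two normalized xyxy boxes; the second row stands in for padding
boxes = torch.tensor([[[0.1, 0.2, 0.5, 0.6], [0.0, 0.0, 0.0, 0.0]]])  # (B=1, N=2, 4)
emb = get_fourier_embeds_from_boundingbox(8, boxes)
# embed_dim=8 frequencies x 2 (sin/cos) x 4 coordinates = 64 channels per box
print(emb.shape)  # torch.Size([1, 2, 64])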
diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py
index ddf533d3bd3b..623e4d88d564 100644
--- a/src/diffusers/models/unet_2d_condition.py
+++ b/src/diffusers/models/unet_2d_condition.py
@@ -32,10 +32,10 @@
 )
 from .embeddings import (
     GaussianFourierProjection,
+    GLIGENTextBoundingboxProjection,
     ImageHintTimeEmbedding,
     ImageProjection,
     ImageTimeEmbedding,
-    PositionNet,
     TextImageProjection,
     TextImageTimeEmbedding,
     TextTimeEmbedding,
@@ -615,7 +615,7 @@ def __init__(
             positive_len = cross_attention_dim[0]

             feature_type = "text-only" if attention_type == "gated" else "text-image"
-            self.position_net = PositionNet(
+            self.position_net = GLIGENTextBoundingboxProjection(
                 positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
             )
diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py
index 112aa42323f9..7c9936a0bd4e 100644
--- a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py
+++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py
@@ -187,7 +187,7 @@ def __call__(self, x):
         return torch.stack((x.sin(), x.cos()), dim=-1).permute(0, 1, 3, 4, 2).reshape(*x.shape[:2], -1)


-class PositionNet(nn.Module):
+class GLIGENTextBoundingboxProjection(nn.Module):
     def __init__(self, positive_len, out_dim, feature_type, fourier_freqs=8):
         super().__init__()
         self.positive_len = positive_len
@@ -820,7 +820,7 @@ def __init__(
             positive_len = cross_attention_dim[0]

             feature_type = "text-only" if attention_type == "gated" else "text-image"
-            self.position_net = PositionNet(
+            self.position_net = GLIGENTextBoundingboxProjection(
                 positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
             )
diff --git a/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py b/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py
index 91d7357fd352..632e696392d8 100644
--- a/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py
+++ b/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py
@@ -730,7 +730,7 @@ def __call__(
             )
             gligen_phrases = gligen_phrases[:max_objs]
             gligen_boxes = gligen_boxes[:max_objs]
-            # prepare batched input to the PositionNet (boxes, phrases, mask)
+            # prepare batched input to the GLIGENTextBoundingboxProjection (boxes, phrases, mask)
             # Get tokens for phrases from pre-trained CLIPTokenizer
             tokenizer_inputs = self.tokenizer(gligen_phrases, padding=True, return_tensors="pt").to(device)
             # For the token, we use the same pre-trained text encoder
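For context on the renamed GLIGENTextBoundingboxProjection: its forward pass (unchanged here apart from the helper call) swaps the Fourier embedding of padded boxes for a learned null embedding using the mask. A minimal standalone sketch of that masking step, using hypothetical stand-in tensors rather than the actual module:

import torch

batch_size, num_boxes, channels = 1, 2, 64
xyxy_embedding = torch.randn(batch_size, num_boxes, channels)  # Fourier embeddings, padding rows included
masks = torch.tensor([[1.0, 0.0]]).unsqueeze(-1)               # (B, N, 1); 0 marks a padded box
null_position_feature = torch.randn(channels)                  # stand-in for the learned null embedding

# learnable null embedding replaces the padded positions
xyxy_null = null_position_feature.view(1, 1, -1)
xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null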
diff --git a/tests/models/test_models_unet_2d_condition.py b/tests/models/test_models_unet_2d_condition.py
index 35ea33328c1d..0e2a4765d6ca 100644
--- a/tests/models/test_models_unet_2d_condition.py
+++ b/tests/models/test_models_unet_2d_condition.py
@@ -26,7 +26,7 @@
 from diffusers import UNet2DConditionModel
 from diffusers.models.attention_processor import CustomDiffusionAttnProcessor, IPAdapterAttnProcessor
-from diffusers.models.embeddings import ImageProjection, Resampler
+from diffusers.models.embeddings import ImageProjection, IPAdapterPlusImageProjection
 from diffusers.utils import logging
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import (
@@ -133,7 +133,7 @@ def create_ip_adapter_plus_state_dict(model):
     # "image_proj" (ImageProjection layer weights)
     cross_attention_dim = model.config["cross_attention_dim"]

-    image_projection = Resampler(
+    image_projection = IPAdapterPlusImageProjection(
         embed_dims=cross_attention_dim, output_dims=cross_attention_dim, dim_head=32, heads=2, num_queries=4
     )
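And a usage sketch of the renamed IP-Adapter Plus projection, mirroring the test helper's toy dimensions (assumes a build with this change installed; hidden_dims and depth are left at their defaults, and the output shape reflects the resampler reducing the token sequence to num_queries learned latents):

import torch
from diffusers.models.embeddings import IPAdapterPlusImageProjection

image_projection = IPAdapterPlusImageProjection(
    embed_dims=32, output_dims=32, dim_head=32, heads=2, num_queries=4
)
# (batch, num_image_tokens, embed_dims) CLIP hidden states -> (batch, num_queries, output_dims)
image_embeds = torch.randn(2, 16, 32)
print(image_projection(image_embeds).shape)  # torch.Size([2, 4, 32])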