Add GLIGEN implementation #4441
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
@patrickvonplaten @sayakpaul @stevhliu
@sayakpaul: Thank you. All checks passed.
Very awesome!
@@ -63,6 +63,62 @@ class UNet2DConditionOutput(BaseOutput):
    sample: torch.FloatTensor = None


class FourierEmbedder(nn.Module):
        return torch.stack((x.sin(), x.cos()), dim=-1).permute(0, 1, 3, 4, 2).reshape(*x.shape[:2], -1)


class PositionNet(nn.Module):
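For context, here is a minimal, hedged sketch of how a Fourier embedder consistent with the `return` line above could be structured (the `num_freqs`/`temperature` defaults are assumptions, not necessarily what the PR ships); it maps box coordinates of shape `(batch, num_boxes, 4)` to sinusoidal features of shape `(batch, num_boxes, 8 * num_freqs)`:

```python
import torch
import torch.nn as nn


class FourierEmbedder(nn.Module):
    # Sketch only: hyperparameter defaults are assumptions.
    def __init__(self, num_freqs: int = 64, temperature: int = 100):
        super().__init__()
        freq_bands = temperature ** (torch.arange(num_freqs) / num_freqs)
        # Shape (1, 1, 1, num_freqs) so it broadcasts over (batch, num_boxes, 4, 1).
        self.register_buffer("freq_bands", freq_bands[None, None, None], persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, num_boxes, 4) normalized box coordinates -> (batch, num_boxes, 4, num_freqs)
        x = self.freq_bands * x.unsqueeze(-1)
        # Interleave sin/cos per coordinate and flatten to (batch, num_boxes, 8 * num_freqs).
        return torch.stack((x.sin(), x.cos()), dim=-1).permute(0, 1, 3, 4, 2).reshape(*x.shape[:2], -1)
```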
@@ -202,6 +258,7 @@ def __init__(
        conv_in_kernel: int = 3,
        conv_out_kernel: int = 3,
        projection_class_embeddings_input_dim: Optional[int] = None,
        use_gated_attention: bool = False,
Suggested change:
-        use_gated_attention: bool = False,
+        attention_type: str = "default",  # gated

When introducing new config variables, let's make sure we can extend them going forward. Using a string-typed variable would be nice.
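To illustrate why the string-typed option is more extensible, here is a small, hedged sketch (the dispatch helper and the set of accepted values beyond `"default"`/`"gated"` are hypothetical, not part of the PR):

```python
# A boolean flag only ever supports two behaviours; a string-typed config value
# can grow to cover future attention variants without new constructor arguments.
def get_attention_variant(attention_type: str = "default") -> str:
    supported = {"default", "gated"}  # hypothetical set of accepted values
    if attention_type not in supported:
        raise ValueError(f"`attention_type` must be one of {supported}, got {attention_type!r}.")
    return attention_type
```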
> When introducing new config variables, let's make sure we can extend them going forward. Using a string-typed variable would be nice.

@sayakpaul To address this comment, I had to create a new repo for the weights so that I could modify the UNet config file.
The model weights are an exact copy of the original weights.
We should then submit PRs to the original model repository and tag the authors there.
I can do that, but it will break their fork of diffusers. I'm not sure if they would prefer that.
So, IIUC the existing checkpoints from the gligen organization won't work with the current implementation that is being added in the PR?
I completely agree with you. I've reached out to the author haotian-liu at [email protected] to check out this PR, but haven't heard back from them.
Maybe open a discussion on their model repository? Feel free to tag me.
Opened a discussion here
Alright thanks much!
Meanwhile, I think we can knock off the other pending comments.
Thank you. Will fix the pending comments today.
Looks good to me in general! @sayakpaul @yiyixuxu do you want to give this a pass?
Yes, I will.
class StableDiffusionGLIGENPipeline(DiffusionPipeline):
    r"""
    Pipeline for text-to-image generation using Stable Diffusion.
This docstring needs to change.
@nikhil-masterful until we hear back from the GLIGEN authors, I think it's okay to have all the checkpoints under your HF profile with the latest configs. WDYT?
Agreed
@sayakpaul @yiyixuxu: Added a FastTest as requested.
Looking fantastic! Thanks so much for iterating! I guess the only remaining items are:
- https://github.com/huggingface/diffusers/pull/4441/files#r1289496701
- Add GLIGEN implementation #4441 (comment)
Let me know if anything is unclear.
Works perfectly! @stevhliu is this how you would have expected multiple example use cases for a pipeline to be included in the corresponding doc?
This is about having clones of the original GLIGEN checkpoints under your HF profile with the new configuration change. This will allow users to use all the available GLIGEN checkpoints directly.

Also, instead of doing this, could we update the examples to something like so?

```python
from diffusers.utils import make_image_grid

images = pipe(
    prompt=prompt,
    num_images_per_prompt=1,
    gligen_phrases=phrases,
    gligen_boxes=boxes,
    gligen_scheduled_sampling_beta=1,
    num_inference_steps=50,
).images

make_image_grid(images, rows, cols, resize=256)
```

It's simpler and doesn't make use of additional dependencies.
Yeah, since there aren't separate pipelines (for example, …)
Thanks for iterating! It looks much better now.
I left some more comments here :)
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
Can we move the logic to check these three parameters into this function instead? `gligen_phrases`, `gligen_boxes`, `gligen_inpaint_image`
Moved `gligen_phrases` and `gligen_boxes` to `check_inputs`. `gligen_inpaint_image` is always valid because `None` is also a valid input for the text2img case.
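For reference, a hedged sketch of what that validation could look like as a standalone helper (the exact error messages and how it is wired into `check_inputs` are assumptions):

```python
def check_gligen_inputs(gligen_phrases, gligen_boxes):
    # Sketch only: each grounding phrase must have a matching bounding box,
    # and each box must be a 4-tuple of normalized (xmin, ymin, xmax, ymax) coordinates.
    if len(gligen_phrases) != len(gligen_boxes):
        raise ValueError(
            f"`gligen_phrases` has {len(gligen_phrases)} entries but `gligen_boxes` has "
            f"{len(gligen_boxes)}; they must have the same length."
        )
    for box in gligen_boxes:
        if len(box) != 4:
            raise ValueError(f"Each box in `gligen_boxes` must have 4 coordinates, got {box!r}.")
```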
        dtype = self.text_encoder.dtype
        # For each entity, described in phrases, is denoted with a bounding box,
        # we represent the location information as (xmin,ymin,xmax,ymax)
        boxes = torch.zeros(max_objs, 4, device=device, dtype=dtype)
        boxes[:n_objs] = torch.tensor(_boxes[:n_objs])
        text_embeddings = torch.zeros(max_objs, self.unet.cross_attention_dim, device=device, dtype=dtype)
        text_embeddings[:n_objs] = _text_embeddings[:n_objs]
        # Generate a mask for each object that is entity described by phrases
        masks = torch.zeros(max_objs, device=device, dtype=dtype)
        masks[:n_objs] = 1
2 questions here:

- Is there any reason `_text_embedding`, `gligen_boxes` would not have length `n_objs`? I made my suggestions based on the assumption that `_text_embedding`, `gligen_boxes` are already at length `n_objs`, but let me know if that's not the case.
- Do `boxes`, `text_embeddings` have to be a fixed size `max_objs`? Can't they just be the same size as the number of objects we passed, so we don't have to fill the rest of the tensor with `0`s?
Suggested change (replacing the block quoted above):

        # For each entity, described in phrases, is denoted with a bounding box,
        # we represent the location information as (xmin,ymin,xmax,ymax)
        boxes = torch.zeros(max_objs, 4, device=device, dtype=self.text_encoder.dtype)
        boxes[:n_objs] = torch.tensor(gligen_boxes)
        text_embeddings = torch.zeros(max_objs, self.unet.cross_attention_dim, device=device, dtype=self.text_encoder.dtype)
        text_embeddings[:n_objs] = _text_embeddings
        # Generate a mask for each object that is entity described by phrases
        masks = torch.zeros(max_objs, device=device, dtype=self.text_encoder.dtype)
        masks[:n_objs] = 1
Changed `_text_embedding`, `gligen_boxes` to have length `n_objs`.

Do `boxes`, `text_embeddings` have to be a fixed size `max_objs`? Yes, that's how the GLIGEN authors intended it to be. It would be good to keep it that way for now.
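For anyone following along, a small standalone sketch of the fixed-size padding scheme discussed here (the `max_objs` value and tensor widths are illustrative assumptions; the pipeline code above is authoritative):

```python
import torch

max_objs = 30               # assumed fixed cap on grounded objects
cross_attention_dim = 768   # assumed text-embedding width

# Suppose two grounded objects: one text embedding and one normalized box per object.
obj_embeddings = torch.randn(2, cross_attention_dim)
obj_boxes = [[0.1, 0.2, 0.5, 0.6], [0.4, 0.4, 0.9, 0.9]]
n_objs = len(obj_boxes)

# Pad every grounding tensor to the fixed max_objs size; the mask marks which slots are real.
boxes = torch.zeros(max_objs, 4)
boxes[:n_objs] = torch.tensor(obj_boxes)
text_embeddings = torch.zeros(max_objs, cross_attention_dim)
text_embeddings[:n_objs] = obj_embeddings
masks = torch.zeros(max_objs)
masks[:n_objs] = 1
```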
@nikhil-masterful I only see two checkpoints here: https://huggingface.co/masterful. Are these the only two officially supported by GLIGEN? If so, yeah, then that's checked. @yiyixuxu could you give this a final look?
Yes, those were the only two officially supported by GLIGEN.
@sayakpaul @yiyixuxu Thanks for reviewing this. I've fixed all the outstanding comments. Please let me know if I missed anything.
Looking great to me! Thanks!
@sayakpaul if it looks good, can we merge, please?
Thanks so much for iterating!
Thanks for helping me make this contribution. It was a great experience. @sayakpaul I would like to continue contributing to diffusers. Could you please direct me to any outstanding bug/feature that I can work on? I can iterate on things faster.
Thanks so much for being willing to do that! I would redirect you to our issues thread; see what interests you and we can take it from there.
* Add GLIGEN implementation
* GLIGEN: Fix code quality check failures
* GLIGEN: Fix Import block un-sorted or un-formatted failures
* GLIGEN: Fix check_repository_consistency failures
* GLIGEN: Add 'PositionNet' to versatile_diffusion/modeling_text_unet.py
* GLIGEN: check_repository_consistency: fix 'copy does not match' error
* GLIGEN: Fix review comments (1)
* GLIGEN: Fix E721 Do not compare types, use `isinstance()` failures
* GLIGEN: Ensure _encode_prompt() copy matches to StableDiffusionPipeline
* GLIGEN: Fix ruff E721 failure in unidiffuser/test_unidiffuser.py
* GLIGEN: doc_builder: restyle pipeline_stable_diffusion_gligen.py
* GIGLEN: reset files unrelated to gligen
* GLIGEN: Fix documentation comments (1)
* GLIGEN: Fix review comments (2)
* GLIGEN: Added FastTest
* GLIGEN: Fix review comments (3)
What does this PR do?
GLIGEN: Open-Set Grounded Text-to-Image Generation (CVPR 2023)
Project page - https://gligen.github.io/
Paper - https://arxiv.org/abs/2301.07093
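For reviewers who want to try the new pipeline locally, here is a minimal, hedged usage sketch based on the call signature discussed in this thread (the checkpoint name refers to the generation checkpoint under the `masterful` profile mentioned above and should be treated as an assumption; the prompt, phrases, and box coordinates are illustrative):

```python
import torch
from diffusers import StableDiffusionGLIGENPipeline

# Assumed checkpoint name; see the checkpoint discussion in the conversation above.
pipe = StableDiffusionGLIGENPipeline.from_pretrained(
    "masterful/gligen-1-4-generation-text-box", torch_dtype=torch.float16
).to("cuda")

prompt = "a waterfall and a modern high speed train in a beautiful forest"
phrases = ["a waterfall", "a modern high speed train"]
# One normalized (xmin, ymin, xmax, ymax) box per phrase; values are illustrative.
boxes = [[0.14, 0.21, 0.43, 0.71], [0.50, 0.44, 0.85, 0.73]]

images = pipe(
    prompt=prompt,
    num_images_per_prompt=1,
    gligen_phrases=phrases,
    gligen_boxes=boxes,
    gligen_scheduled_sampling_beta=1,
    num_inference_steps=50,
).images

images[0].save("gligen_text_box.png")
```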
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.