-
Notifications
You must be signed in to change notification settings - Fork 5.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support SD3 ControlNet and Multi-ControlNet. #8566
Conversation
@haofanwang 交给浩哥了❤️❤️ 我去准备权重和demo了 |
demo and weight are here https://huggingface.co/InstantX/SD3-Controlnet-Canny_alpha_512 import torch
from diffusers import StableDiffusion3Pipeline
from diffusers.models.controlnet_sd3 import ControlNetSD3Model
from diffusers.utils.torch_utils import randn_tensor
import sys, os
sys.path.append('/path/diffusers/examples/community')
from pipeline_stable_diffusion_3_controlnet import StableDiffusion3CommonPipeline
# load pipeline
base_model = 'stabilityai/stable-diffusion-3-medium-diffusers'
pipe = StableDiffusion3CommonPipeline.from_pretrained(
base_model,
controlnet_list=['InstantX/SD3-Controlnet-Canny_alpha_512']
)
pipe.to('cuda:0', torch.float16)
prompt = 'Anime style illustration of a girl wearing a suit. In the background we see a big rain approaching.'
n_prompt = 'NSFW, nude, naked, porn, ugly'
# controlnet config
controlnet_conditioning = [
dict(
control_index=0,
control_image=load_image('https://huggingface.co/InstantX/SD3-Controlnet-Canny_alpha_512/resolve/main/canny.jpg'),
control_weight=0.5,
control_pooled_projections='zeros'
)
]
# infer
image = pipe(
prompt=prompt,
negative_prompt=n_prompt,
controlnet_conditioning=controlnet_conditioning,
num_inference_steps=28,
guidance_scale=7.0,
height=512,
width=512,
latents=latents,
).images[0] |
Our teammate has implemented ControlNet for SD3 and trained a canny model for testing. Could you review this PR? @sayakpaul @yiyixuxu |
I would be in favor of supporting this through the core codebase actually. So, would like to first seek opinions from @yiyixuxu first before reviewing. In any case, I truly appreciate your hard work here! Solid! |
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks a ton for the PR! 🔥🔥🔥🔥🔥🔥🔥🔥
agree with @sayakpaul here: we want to integrate this into core directly
Please add support for controlnet image2image pipeline something like StableDiffusionControlNetImg2ImgPipeline example from sd1.5:
|
@yiyixuxu let me know if you'd like me to review as well. |
@sayakpaul We will update soon based on comments above. Then you can review again. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
looks good to me!!
Left a few comments, I think we can merge very soon!
will also need tests and doc
src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
Outdated
Show resolved
Hide resolved
pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) | ||
|
||
# 3. Prepare control image | ||
if isinstance(self.controlnet, SD3ControlNetModel): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
so I liked that in the initial implementation, you made sure the controlnet is a MultiControlNetModel
regardless how many controllers are passed or in which format. you did this in __init__
:controlnet_list = MultiControlNetSD3Model(controlnet_list)
it is different from our code but an improvement I think, because now we only need to deal with MultiControlNetModel
. I think we can keep it as it is and refactor all our controlnet together this way in a follow-up PR:)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I suggest to keep consistency with other ControlNet pipelines for now, and we can refactor them together later.
Great. It is a nice and hard work~ Can you also share the training code? |
in_channels=in_channels, | ||
embed_dim=self.inner_dim, | ||
pos_embed_type=None, | ||
# pos_embed_max_size=pos_embed_max_size, # hard-code for now. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this not needed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also have a question regarding using to PatchEmbed
layers. Would love to understand this more and why one has pos_embed_max_size
has defined and another one has it None
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
they did not add a positional embedding here, so it does not need pos_embed_max_size
, but yeah, +1 on @sayakpaul 's questions - is there any reason we skip this for control input?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
they did not add a positional embedding here, so it does not need
pos_embed_max_size
, but yeah, +1 on @sayakpaul 's questions - is there any reason we skip this for control input?
yes, In the forward method of ControlNet, the position embedding only needs to be added once. Therefore, only one of the two PatchEmbed includes positional information.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good 👍🏽 Could we please add a fast tests for the ControlNet model and ControlNetPipeline
@@ -40,7 +40,7 @@ | |||
"tensorboard": "tensorboard", | |||
"torch": "torch>=1.4", | |||
"torchvision": "torchvision", | |||
"transformers": "transformers>=4.25.1", | |||
"transformers": "transformers>=4.41.2", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this still needed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is reformatted automatically after make style
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No I mean this is not just a formatting change, it’s changing the version. We have had this change merged recently from another PR actually.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@haofanwang I think you need to merge main (these change are already in main)
controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict(), strict=False) | ||
controlnet.pos_embed_input.load_state_dict(transformer.pos_embed.state_dict(), strict=False) | ||
controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict(), strict=False) | ||
controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict(), strict=False) | ||
controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False) | ||
|
||
controlnet.pos_embed_input = zero_module(controlnet.pos_embed_input) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we perhaps just do controlnet.load_state_dict(transformer.state_dict(), strict=False)
or is it too risky?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we wouldn't prefer that
actually only need strict=False
for transformer_blocks
, no? @haofanwang, the other layers should be identical and be able to load without strict=False
. this way:
- when we read the code, we know which layers are identical, which layers are not
- get a better error message when the checkpoints are wrong
I don't feel too strongly about this, so ok if you just want to keep it as it is
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah that is better indeed. Thanks for explaining.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we wouldn't prefer that
actually only need
strict=False
fortransformer_blocks
, no? @haofanwang, the other layers should be identical and be able to load withoutstrict=False
. this way:
- when we read the code, we know which layers are identical, which layers are not
- get a better error message when the checkpoints are wrong
I don't feel too strongly about this, so ok if you just want to keep it as it is
Among all the involved members, only the initialization of transformer_blocks
need requires strict=False
, and pos_embed_input
should be initialized with all zeros. controlnet.pos_embed_input.load_state_dict(transformer.pos_embed.state_dict(), strict=False)
can be deleted.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@DN6 @sayakpaul @yiyixuxu @stevhliu Added test and doc. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good so far! Remember to add controlnet_sd3.md
to the toctree
in the models and pipelines sections so their docs get built 🙂
from diffusers.models import SD3ControlNetModel, SD3MultiControlNetModel | ||
|
||
controlnet = SD3ControlNetModel.from_pretrained("InstantX/SD3-Controlnet-Canny") | ||
pipe = StableDiffusion3ControlNetPipeline.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers",controlnet=controlnet) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
pipe = StableDiffusion3ControlNetPipeline.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers",controlnet=controlnet) | |
pipe = StableDiffusion3ControlNetPipeline.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", controlnet=controlnet) |
|
||
*We present ControlNet, a neural network architecture to add spatial conditioning controls to large, pretrained text-to-image diffusion models. ControlNet locks the production-ready large diffusion models, and reuses their deep and robust encoding layers pretrained with billions of images as a strong backbone to learn a diverse set of conditional controls. The neural architecture is connected with "zero convolutions" (zero-initialized convolution layers) that progressively grow the parameters from zero and ensure that no harmful noise could affect the finetuning. We test various conditioning controls, eg, edges, depth, segmentation, human pose, etc, with Stable Diffusion, using single or multiple conditions, with or without prompts. We show that the training of ControlNets is robust with small (<50k) and large (>1m) datasets. Extensive results show that ControlNet may facilitate wider applications to control image diffusion models.* | ||
|
||
SD3ControlNetModel is an implementation of ControlNet for Stable Diffusion 3. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you can move this to the beginning so users reading it immediately know what this is. For example:
SD3ControlNetModel is an implementation of ControlNet for Stable Diffusion 3. The ControlNet model was introduced in ...
|
||
# ControlNet with Stable Diffusion 3 | ||
|
||
ControlNet was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang, Anyi Rao, and Maneesh Agrawala. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe also add something similar here:
StableDiffusion3ControlNetPipeline is an implementation of ControlNet for Stable Diffusion 3. ControlNet was introduced ...
|
||
class SD3MultiControlNetModel(ModelMixin): | ||
r""" | ||
Multiple `SD3ControlNetModel` wrapper class for Multi-SD3ControlNet |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Multiple `SD3ControlNetModel` wrapper class for Multi-SD3ControlNet | |
`SD3ControlNetModel` wrapper class for Multi-SD3ControlNet. |
specific language governing permissions and limitations under the License. | ||
--> | ||
|
||
# SD3ControlNetModel |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
also need to add the new doc pages to https://github.com/huggingface/diffusers/blob/main/docs/source/en/_toctree.yml
Done. Should be ready to merge:D |
@haofanwang have to run fix-copies again since we updated the encode_prompt on sd3 pipeline 😬 sorry! hope this is the last time! |
thank you! |
* sd3 controlnet --------- Co-authored-by: haofanwang <[email protected]>
What does this PR do?
demo and infer code
Fixes # (issue)
Before submitting
documentation guidelines, and
here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.