
Support SD3 ControlNet and Multi-ControlNet. #8566

Merged: 17 commits into huggingface:main, Jun 19, 2024
Conversation

@wangqixun (Contributor) commented Jun 15, 2024

What does this PR do?

  1. Supports SD3 ControlNet.
  2. Supports SD3 Multi-ControlNet.
  3. Implements a pipeline that supports SD3 Multi-ControlNet.

Demo and inference code are included below.

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@wangqixun (Contributor, Author)

@haofanwang Handing this over to you ❤️❤️ I'll go prepare the weights and demo.

@wangqixun (Contributor, Author)

The demo and weights are here:

https://huggingface.co/InstantX/SD3-Controlnet-Canny_alpha_512

[example output image]

import sys
import torch
from diffusers.utils import load_image  # needed for loading the control image below

# The community pipeline lives in the diffusers examples folder.
sys.path.append('/path/diffusers/examples/community')
from pipeline_stable_diffusion_3_controlnet import StableDiffusion3CommonPipeline

# load pipeline
base_model = 'stabilityai/stable-diffusion-3-medium-diffusers'
pipe = StableDiffusion3CommonPipeline.from_pretrained(
    base_model,
    controlnet_list=['InstantX/SD3-Controlnet-Canny_alpha_512'],
)
pipe.to('cuda:0', torch.float16)

prompt = 'Anime style illustration of a girl wearing a suit. In the background we see a big rain approaching.'
n_prompt = 'NSFW, nude, naked, porn, ugly'

# controlnet config: one dict per ControlNet in controlnet_list
controlnet_conditioning = [
    dict(
        control_index=0,
        control_image=load_image('https://huggingface.co/InstantX/SD3-Controlnet-Canny_alpha_512/resolve/main/canny.jpg'),
        control_weight=0.5,
        control_pooled_projections='zeros',
    )
]

# infer
# (the original snippet also passed a pre-built latents tensor; see the sketch after this block)
image = pipe(
    prompt=prompt,
    negative_prompt=n_prompt,
    controlnet_conditioning=controlnet_conditioning,
    num_inference_steps=28,
    guidance_scale=7.0,
    height=512,
    width=512,
).images[0]
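
The snippet originally imported randn_tensor and passed a latents argument to the pipeline. A minimal sketch of how those latents might be built for reproducible outputs; the 16-channel latent shape and factor-8 spatial downscaling are assumptions based on SD3's VAE, not something stated in this PR:

import torch
from diffusers.utils.torch_utils import randn_tensor

# assumed SD3 latent shape: (batch, channels, height // 8, width // 8)
shape = (1, 16, 512 // 8, 512 // 8)
generator = torch.Generator('cuda:0').manual_seed(0)
latents = randn_tensor(shape, generator=generator, device=torch.device('cuda:0'), dtype=torch.float16)
# then pass latents=latents to the pipe(...) call above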

@haofanwang (Contributor)

Our teammate has implemented ControlNet for SD3 and trained a canny model for testing. Could you review this PR? @sayakpaul @yiyixuxu

@wangqixun (Contributor, Author)

Beta 1024-pixel canny model:

[example output image]

@sayakpaul (Member)

I would actually be in favor of supporting this through the core codebase.

So I would like to seek opinions from @yiyixuxu before reviewing.

In any case, I truly appreciate your hard work here! Solid!


@yiyixuxu (Collaborator) left a comment

thanks a ton for the PR! 🔥🔥🔥🔥🔥🔥🔥🔥
agree with @sayakpaul here: we want to integrate this into core directly

@s9anus98a

Please add support for a ControlNet image-to-image pipeline, something like StableDiffusionControlNetImg2ImgPipeline from SD 1.5:

pipe = StableDiffusionImg2ImgPipeline.from_pipe(
    pipe,
    custom_pipeline="jyoung105/sd15_perturbed_attention_guidance_i2i",
    torch_dtype=torch.float16,
).to("cuda")

pipe = StableDiffusionControlNetImg2ImgPipeline.from_pipe(
    pipe,
    controlnet=controlnet,
    torch_dtype=torch.float16,
).to("cuda")
pipe.enable_xformers_memory_efficient_attention()
pipe.enable_vae_slicing()
pipe.enable_vae_tiling()

@sayakpaul (Member)

@yiyixuxu let me know if you'd like me to review as well.

@haofanwang (Contributor)

@sayakpaul We will update soon based on the comments above. Then you can review again.

@yiyixuxu (Collaborator) left a comment

Looks good to me!
I left a few comments; I think we can merge very soon!

We will also need tests and docs.

pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)

# 3. Prepare control image
if isinstance(self.controlnet, SD3ControlNetModel):

Collaborator:

I liked that in the initial implementation you made sure the controlnet is a MultiControlNetModel regardless of how many controlnets are passed or in which format. You did this in __init__: controlnet_list = MultiControlNetSD3Model(controlnet_list)

It is different from our code but an improvement, I think, because now we only need to deal with MultiControlNetModel. I think we can keep it as it is and refactor all our controlnets together this way in a follow-up PR :)
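
A minimal sketch of the normalization described above (class name as quoted in the comment; illustrative, not the merged code):

# sketch: normalize whatever the caller passes into one wrapper type in __init__,
# so downstream code only ever deals with the multi-controlnet class
if isinstance(controlnet_list, (list, tuple)):
    controlnet_list = MultiControlNetSD3Model(controlnet_list)
elif not isinstance(controlnet_list, MultiControlNetSD3Model):
    controlnet_list = MultiControlNetSD3Model([controlnet_list])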

Contributor:

I suggest keeping consistency with the other ControlNet pipelines for now; we can refactor them together later.

@yiyixuxu requested review from DN6 and sayakpaul, June 16, 2024 21:39
@appleyang123

Great, this is nice and hard work! Can you also share the training code?

in_channels=in_channels,
embed_dim=self.inner_dim,
pos_embed_type=None,
# pos_embed_max_size=pos_embed_max_size, # hard-code for now.

Collaborator:

Is this not needed?

Member:

I also have a question regarding the two PatchEmbed layers. I would love to understand why one has pos_embed_max_size defined and the other has it set to None.

Collaborator:

They did not add a positional embedding here, so it does not need pos_embed_max_size. But yeah, +1 on @sayakpaul's question: is there any reason we skip this for the control input?

Contributor (Author):

> is there any reason we skip this for the control input?

Yes. In the forward method of the ControlNet, the position embedding only needs to be added once. Therefore, only one of the two PatchEmbed layers includes positional information.
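
A minimal sketch of that design (function and argument names are assumptions for illustration): the two patch embeddings are summed, so including positional information in both would count it twice.

def embed_with_control(latent, control, pos_embed, pos_embed_input):
    # pos_embed patchifies the noisy latent AND adds the positional embedding
    hidden_states = pos_embed(latent)
    # pos_embed_input (built with pos_embed_type=None) only patchifies the
    # control latent, so no second positional embedding is added
    return hidden_states + pos_embed_input(control)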

@DN6 (Collaborator) left a comment

Looks good 👍🏽 Could we please add fast tests for the ControlNet model and the ControlNet pipeline?

@@ -40,7 +40,7 @@
     "tensorboard": "tensorboard",
     "torch": "torch>=1.4",
     "torchvision": "torchvision",
-    "transformers": "transformers>=4.25.1",
+    "transformers": "transformers>=4.41.2",

Member:

Is this still needed?

Contributor:

This is reformatted automatically after make style.

Member:

No, I mean this is not just a formatting change; it's changing the version. We have actually had this change merged recently from another PR.


Comment on lines 250 to 256
controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict(), strict=False)
controlnet.pos_embed_input.load_state_dict(transformer.pos_embed.state_dict(), strict=False)
controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict(), strict=False)
controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict(), strict=False)
controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False)

controlnet.pos_embed_input = zero_module(controlnet.pos_embed_input)

Member:

Could we perhaps just do controlnet.load_state_dict(transformer.state_dict(), strict=False), or is it too risky?

Collaborator:

We wouldn't prefer that.

Actually, we only need strict=False for transformer_blocks, no? @haofanwang, the other layers should be identical and should load without strict=False. This way:

  1. when we read the code, we know which layers are identical and which are not
  2. we get a better error message when the checkpoints are wrong

I don't feel too strongly about this, so it's OK if you just want to keep it as it is.

Member:

Yeah, that is better indeed. Thanks for explaining.

Contributor (Author):

> actually only need strict=False for transformer_blocks, no? …

Among all the involved modules, only the initialization of transformer_blocks requires strict=False, and pos_embed_input should be initialized with all zeros, so controlnet.pos_embed_input.load_state_dict(transformer.pos_embed.state_dict(), strict=False) can be deleted.
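
A sketch of the initialization this discussion converges on (module names taken from the snippet quoted above; illustrative, not necessarily the merged code):

# identical modules: strict loading surfaces checkpoint mismatches early
controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())
controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict())
controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict())

# the controlnet keeps fewer transformer blocks than the base model, hence strict=False
controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False)

# zero-init the control-input patch embedding so the control branch starts as a no-op
controlnet.pos_embed_input = zero_module(controlnet.pos_embed_input)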

@sayakpaul (Member) left a comment

This looks very good to me!

TODOs:

  • Docs (please ping @stevhliu once done)
  • Tests

@haofanwang (Contributor) commented Jun 18, 2024

@DN6 @sayakpaul @yiyixuxu @stevhliu

Added tests and docs. make quality and make style passed locally.

@stevhliu (Member) left a comment

Looks good so far! Remember to add controlnet_sd3.md to the toctree in the models and pipelines sections so their docs get built 🙂

from diffusers.models import SD3ControlNetModel, SD3MultiControlNetModel

controlnet = SD3ControlNetModel.from_pretrained("InstantX/SD3-Controlnet-Canny")
pipe = StableDiffusion3ControlNetPipeline.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers",controlnet=controlnet)

Member (suggested change):

-pipe = StableDiffusion3ControlNetPipeline.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers",controlnet=controlnet)
+pipe = StableDiffusion3ControlNetPipeline.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", controlnet=controlnet)

*We present ControlNet, a neural network architecture to add spatial conditioning controls to large, pretrained text-to-image diffusion models. ControlNet locks the production-ready large diffusion models, and reuses their deep and robust encoding layers pretrained with billions of images as a strong backbone to learn a diverse set of conditional controls. The neural architecture is connected with "zero convolutions" (zero-initialized convolution layers) that progressively grow the parameters from zero and ensure that no harmful noise could affect the finetuning. We test various conditioning controls, eg, edges, depth, segmentation, human pose, etc, with Stable Diffusion, using single or multiple conditions, with or without prompts. We show that the training of ControlNets is robust with small (<50k) and large (>1m) datasets. Extensive results show that ControlNet may facilitate wider applications to control image diffusion models.*

SD3ControlNetModel is an implementation of ControlNet for Stable Diffusion 3.

Member:

I think you can move this to the beginning so users reading it immediately know what this is. For example:

SD3ControlNetModel is an implementation of ControlNet for Stable Diffusion 3. The ControlNet model was introduced in ...


# ControlNet with Stable Diffusion 3

ControlNet was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang, Anyi Rao, and Maneesh Agrawala.

Member:

Maybe also add something similar here:

StableDiffusion3ControlNetPipeline is an implementation of ControlNet for Stable Diffusion 3. ControlNet was introduced ...


class SD3MultiControlNetModel(ModelMixin):
r"""
Multiple `SD3ControlNetModel` wrapper class for Multi-SD3ControlNet

Member (suggested change):

-Multiple `SD3ControlNetModel` wrapper class for Multi-SD3ControlNet
+`SD3ControlNetModel` wrapper class for Multi-SD3ControlNet.

@yiyixuxu (Collaborator)

I modified the fast tests here: #8627
All relevant tests passed there; feel free to cherry-pick the last commit! 16e27b9

@haofanwang (Contributor)

Done. Should be ready to merge :D

@yiyixuxu (Collaborator)

@haofanwang you'll have to run fix-copies again since we updated encode_prompt on the SD3 pipeline 😬 Sorry! Hope this is the last time!

@yiyixuxu merged commit e5564d4 into huggingface:main, Jun 19, 2024
14 of 15 checks passed

@yiyixuxu (Collaborator)

Thank you!
