
Improve control net block index for sd3 #9758

Merged · 16 commits · Nov 20, 2024

Conversation

@linjiapro (Contributor) commented Oct 23, 2024

What does this PR do?

The layer configuration for the Control Net in Stable Diffusion 3 models must adhere to the rule that the total number of layers in the SD3 model should be a multiple of the Control Net's layer count.

For SD3.5, which has 38 layers, that leaves only three possible layer counts for the Control Net: 2, 19, or 38. This makes the setup inflexible.
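As a rough illustration of the constraint (a minimal sketch; the layer count is SD3.5 Large's, the enumeration is just for illustration):

```python
# Under the old rule, the controlnet layer count must divide the
# transformer layer count evenly.
num_transformer_layers = 38  # SD3.5 Large

valid_counts = [n for n in range(1, num_transformer_layers + 1)
                if num_transformer_layers % n == 0]
print(valid_counts)  # [1, 2, 19, 38]
```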

Also, qk_norm, context_pre_only_last_layer, and dual_attention_layers are added to match the transformer architecture.

Who can review?

@sayakpaul @yiyixuxu @DN6

@linjiapro (Contributor Author)

cc @sayakpaul @yiyixuxu @DN6

@linjiapro (Contributor Author)

This is a very simple PR, but for some reason the tests all failed with weird errors such as:

Unable to find self-hosted runner group: 'aws-general-8-plus'.

@sayakpaul (Member)

Thanks for your contributions! Could you maybe also add a test for this?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@linjiapro (Contributor Author)

@sayakpaul

There is an existing test for the SD3 ControlNet pipeline, and I leveraged that. The number of layers for the controlnet changed from 1 to 3, so the transformer's layer count (4) is no longer a multiple of the controlnet's layer count. This exercises the code changes in this PR.
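In other words (an illustrative check, not the actual test code):

```python
# The dummy transformer in the pipeline test has 4 layers; the dummy
# controlnet now has 3. Since 4 % 3 != 0, the old floor-division logic
# no longer applies cleanly, and the new interval computation is exercised.
num_transformer_layers = 4
num_controlnet_layers = 3
assert num_transformer_layers % num_controlnet_layers != 0
```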

@sayakpaul (Member) left a comment

Thanks, I left some comments.

.gitignore Outdated
```diff
@@ -102,6 +102,7 @@ venv/
 ENV/
 env.bak/
 venv.bak/
+myenv/
```
Member: Should be removed.

Contributor Author: Done

```diff
@@ -77,7 +77,7 @@ def get_dummy_components(self):
     sample_size=32,
     patch_size=1,
     in_channels=8,
-    num_layers=1,
+    num_layers=3,
```
Member: We shouldn't change this value here, I think. Instead, we could make this method accept an argument like num_controlnet_layers and then leverage it as needed. WDYT?
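Something along these lines, perhaps (a sketch; `num_controlnet_layers` is the argument name suggested above, the other kwargs are copied from the diff, and the class name is assumed to be the one the test already uses):

```python
def get_dummy_components(self, num_controlnet_layers: int = 3):
    controlnet = SD3ControlNetModel(
        sample_size=32,
        patch_size=1,
        in_channels=8,
        num_layers=num_controlnet_layers,
        # ... remaining kwargs unchanged from the existing setup ...
    )
    # ... rest of the method unchanged ...
```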

Contributor Author: Done

@linjiapro (Contributor Author) commented Nov 15, 2024

@sayakpaul @yiyixuxu, can we take a look at this? Thanks

@bghira (Contributor) commented Nov 16, 2024

maybe @DN6 is more active.

```diff
@@ -344,7 +345,8 @@ def custom_forward(*inputs):

 # controlnet residual
 if block_controlnet_hidden_states is not None and block.context_pre_only is False:
-    interval_control = len(self.transformer_blocks) // len(block_controlnet_hidden_states)
+    interval_control = len(self.transformer_blocks) / len(block_controlnet_hidden_states)
```
@yiyixuxu (Collaborator)

why are we making this change? It is not the same, so a breaking change, no?

Contributor Author (@linjiapro)

@yiyixuxu Good question.

The revised code adapts the strategy used by ControlNet for Flux, introducing a significant improvement in flexibility. Here's why this change matters:

In the old code, the number of transformer layers must be divisible by the number of ControlNet layers. For example, with SD3.5 Large, which has 38 transformer layers, there were only two valid options for the number of ControlNet layers: 2 and 19. Setting the number of ControlNet layers to anything else, such as 5, would cause the old code to crash.

However, the Flux ControlNet approach removes this restriction, allowing greater flexibility in choosing the number of layers. The revised logic essentially mirrors the Flux implementation, enabling more versatile configurations.

Importantly, the new code maintains compatibility with existing setups. If the number of transformer layers is divisible by the number of ControlNet layers, the interval_control remains unchanged, ensuring all previous configurations continue to function seamlessly.
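For reference, a sketch of the Flux-style mapping this comment refers to (the helper name is hypothetical; the ceil-plus-floor-division pattern follows the Flux controlnet logic):

```python
import math

def controlnet_residual_index(block_index: int,
                              num_transformer_layers: int,
                              num_controlnet_layers: int) -> int:
    # Float division plus ceil: any controlnet layer count is allowed,
    # and block_index // interval_control never runs past the end of
    # the controlnet residual list.
    interval_control = math.ceil(num_transformer_layers / num_controlnet_layers)
    return block_index // interval_control

# SD3.5 Large (38 layers) with a 5-layer controlnet, which the old
# floor-division logic could not handle:
indices = [controlnet_residual_index(i, 38, 5) for i in range(38)]
print(indices)  # eight 0s, eight 1s, eight 2s, eight 3s, six 4s
```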

@yiyixuxu (Collaborator)

thanks! I think it's indeed better. I'm just wondering if it would cause issues for a controlnet that was trained with the current logic.
cc @haofanwang here

@linjiapro (Contributor Author) commented Nov 20, 2024

I don't think it will cause any issues for controlnets trained with the old code before this PR.

The reason is that for a controlnet to have been trained with the old code, the number of transformer layers has to be divisible by the number of controlnet layers, and the new logic after this PR does not change the behavior in that scenario.
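A quick sanity check of that claim (a sketch; when the counts divide evenly, the old and new computations agree):

```python
import math

num_transformer_layers, num_controlnet_layers = 38, 19  # divisible case

old_interval = num_transformer_layers // num_controlnet_layers             # 2
new_interval = math.ceil(num_transformer_layers / num_controlnet_layers)   # 2
assert old_interval == new_interval  # identical indexing, so old checkpoints behave the same
```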

@yiyixuxu (Collaborator)

Can you run `make style` and `make fix-copies`?

@yiyixuxu yiyixuxu merged commit 1235862 into huggingface:main Nov 20, 2024
13 of 15 checks passed