@yiyixuxu
Collaborator

@yiyixuxu yiyixuxu commented Dec 14, 2023

This PR spins out ResnetBlockCondNorm2D from ResnetBlock2D. ResnetBlockCondNorm2D is a resnet block whose normalization layer incorporates conditional information, e.g. AdaGroupNorm or SpatialNorm. This resnet block is mainly used by the Kandinsky decoder (movq) and the latent upscaler.
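
To make "a normalization layer that incorporates conditional information" concrete, here is a minimal pure-Python sketch. It assumes an AdaGroupNorm-style modulation of the form `norm(x) * (1 + scale) + shift`, with `scale` and `shift` standing in for values that would be derived from a conditioning embedding. The function names are hypothetical and this is not the diffusers implementation.

```python
import math


def normalize(values, eps=1e-5):
    """Normalize a list of floats to zero mean and (roughly) unit variance."""
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    return [(v - mean) / math.sqrt(var + eps) for v in values]


def conditional_norm(values, scale, shift):
    """Normalize, then modulate with conditioning-derived scale and shift.

    Assumption: `scale` and `shift` would come from a projection of the
    conditioning embedding; here they are passed in directly.
    """
    return [v * (1.0 + scale) + shift for v in normalize(values)]


features = [1.0, 2.0, 3.0, 4.0]
plain = normalize(features)  # no conditioning
conditioned = conditional_norm(features, scale=0.5, shift=2.0)
```

A plain norm layer depends only on the features, while the conditional variant also depends on the extra input. That difference in call signature is presumably why a separate block class is convenient.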

@yiyixuxu
Collaborator Author

still wip, cc @DN6 for awareness

# there is always at least one resnet
resnets = [
ResnetBlock2D(
get_resnet_block(
Collaborator

@DN6 DN6 Dec 15, 2023

I think it might be better to just append the appropriate Resnet block here rather than introduce the block-fetching function. Is the purpose here to ensure backwards compatibility? Or do these blocks use a mix of Resnet classes?

Collaborator Author

@yiyixuxu yiyixuxu Dec 15, 2023

I'm aware that some of the blocks use a mix, i.e. the ones Kandinsky used for MOVQ: DownEncoderBlock2D, AttnDownEncoderBlock2D, UpDecoderBlock2D, AttnUpDecoderBlock2D.

We could use get_resnet_block for these blocks instead - I would prefer that way too, because it is more explicit and readable. I used the function everywhere for now because:

  1. Technically, all these blocks can be configured to use both types, and I don't have complete knowledge of whether it would break existing models. I think it's pretty unlikely for "spatial" norm though, since it is very specific to MOVQ.
  2. I wanted to use this as an example to show how we can ensure backward compatibility. I only split ResnetBlock2D into 2 blocks here in the draft and, as far as I'm aware, ResnetBlockCondNorm2D is only used by the latent upscaler and Kandinsky; but maybe we want to have more blocks with more specific configurations. That can become unmanageable without a block-fetching function.

I agree with you it is better just to use the appropriate Resnet blocks, and very happy to refactor later. I will wait to do that after hearing more about:

  1. feedback on how many blocks we want to split into
  2. what's the best strategy to ensure backward compatibility, given that we might not have complete data on how these blocks are used
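
For readers without the diff open, a block-fetching function of the kind discussed here could look roughly like the sketch below. The two classes are empty stand-ins rather than the real diffusers modules, and the signature of `get_resnet_block` is an assumption, since the draft's actual signature is not shown in this thread.

```python
class ResnetBlock2D:
    """Stand-in for the plain resnet block (hypothetical, not the diffusers class)."""

    def __init__(self, **kwargs):
        self.kwargs = kwargs


class ResnetBlockCondNorm2D:
    """Stand-in for the conditional-norm resnet block (hypothetical)."""

    def __init__(self, **kwargs):
        self.kwargs = kwargs


# Norm types that carry conditional information, per the PR description.
COND_NORM_TYPES = ("ada_group", "spatial")


def get_resnet_block(time_embedding_norm="default", **kwargs):
    """Build the block class that matches the requested norm type.

    Old configs that pass "ada_group" or "spatial" keep working because the
    dispatch happens here instead of inside ResnetBlock2D.
    """
    if time_embedding_norm in COND_NORM_TYPES:
        return ResnetBlockCondNorm2D(time_embedding_norm=time_embedding_norm, **kwargs)
    return ResnetBlock2D(time_embedding_norm=time_embedding_norm, **kwargs)
```

The trade-off raised in the review is visible here: the call site no longer says which class it gets, which is the readability concern, but every caller stays compatible with existing configurations.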

Member

> feedback on how many blocks we want to split into

I think we'd want to split into the most common block types. I think this number is still quite unclear.

> We could use get_resnet_block for these blocks instead - I would prefer that way too, because it is more explicit and readable.

How about we use the fetcher function but also add a comment above to denote which type of ResNet block is being used?

Member

@sayakpaul sayakpaul left a comment

Very nice start here. @yiyixuxu

What do we think about keeping resnet as a dedicated module, like we did in #6129? That way, specific ResNet blocks can have their own script and resnet.py becomes quite clean IMO.

Contributor

@patrickvonplaten patrickvonplaten left a comment

Very nice first attempt here! I agree with @DN6 that creating a get_resnet_block(...) is not really making the code much more readable. Instead it would be better if we can directly add the correct Resnet class to the blocks.

I think most blocks only use one of the two ResNets - in this case can we add the correct ResNet right away? In the exceptional case where the block can use both ResNets, let's add an if-else statement?
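
The if-else alternative suggested here might be sketched as follows for a block that can use both ResNets. The classes are empty stand-ins, and `build_resnets` with its parameters is a hypothetical helper for illustration, not code from the PR.

```python
class ResnetBlock2D:
    """Stand-in for the plain resnet block (hypothetical, not the diffusers class)."""

    def __init__(self, **kwargs):
        self.kwargs = kwargs


class ResnetBlockCondNorm2D:
    """Stand-in for the conditional-norm resnet block (hypothetical)."""

    def __init__(self, **kwargs):
        self.kwargs = kwargs


def build_resnets(num_layers, time_embedding_norm="default"):
    """Collect the resnets of one block, choosing the class explicitly."""
    resnets = []
    for _ in range(num_layers):
        if time_embedding_norm == "spatial":
            # Exceptional case: a MOVQ-style block that needs SpatialNorm.
            resnets.append(ResnetBlockCondNorm2D(time_embedding_norm="spatial"))
        else:
            resnets.append(ResnetBlock2D(time_embedding_norm=time_embedding_norm))
    return resnets
```

Blocks that only ever use one of the two classes would skip the branch and instantiate that class directly.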

@yiyixuxu
Collaborator Author

@sayakpaul

> What do we think about keeping resnet as a dedicated module, like we did in #6129? That way, specific ResNet blocks can have their own script and resnet.py becomes quite clean IMO.

Can we do these in future PRs? I would rather save these types of tasks for last, i.e. splitting one file into multiple, moving them around, etc.

@yiyixuxu
Collaborator Author

@patrickvonplaten @DN6
I removed get_resnet_block, let me know if you have any other feedback!

Contributor

@patrickvonplaten patrickvonplaten left a comment

Nice looking good - got only one comment regarding the deprecation warning

Co-authored-by: Patrick von Platen <[email protected]>
Comment on lines 251 to 262
    deprecate(
        "ada_group",
        "1.0.0",
        "Passing `ada_group` as `time_embedding_norm` is deprecated, please create `ResnetBlockCondNorm2D` instead",
    )
if time_embedding_norm == "spatial":
    raise ValueError(
        "spatial",
        "1.0.0",
        "Passing `spatial` as `time_embedding_norm` is deprecated, please create `ResnetBlockCondNorm2D` instead",
    )

Collaborator

Why use deprecate for one and ValueError for the other?

Collaborator Author

oops 😅 fixed it

Collaborator

@DN6 DN6 left a comment

It's looking really good! 👍🏽

Comment on lines 252 to 253
"Passing `ada_group` as `time_embedding_norm` is deprecated, please create `ResnetBlockCondNorm2D` instead",
)
Member

I am not sure if saying something is deprecated is a good idea in a "ValueError". Have we ever done that?

Contributor

Yeah, let's indeed reword the message here to something like "This class cannot be used with `type==ada_group`, please use XXX instead".
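
A check along the lines suggested here could be sketched like this. The helper name `check_time_embedding_norm` is hypothetical and the message wording only approximates the suggestion above; the merged code may differ.

```python
def check_time_embedding_norm(time_embedding_norm):
    """Reject norm types that now belong to ResnetBlockCondNorm2D.

    Both values raise the same kind of error, and the message states what is
    unsupported instead of calling it deprecated.
    """
    if time_embedding_norm in ("ada_group", "spatial"):
        raise ValueError(
            f"This class cannot be used with `time_embedding_norm=={time_embedding_norm}`, "
            "please use `ResnetBlockCondNorm2D` instead"
        )
    return time_embedding_norm
```

Calling `check_time_embedding_norm("spatial")` raises a `ValueError`, while any other value is returned unchanged.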

Collaborator Author

updated!

Member

@sayakpaul sayakpaul left a comment

Very nice. Just one comment.

Contributor

@patrickvonplaten patrickvonplaten left a comment

Nice!

@yiyixuxu yiyixuxu merged commit dd4459a into main Jan 10, 2024
@yiyixuxu yiyixuxu deleted the resnet2d branch January 10, 2024 07:38
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024
…6166)



---------

Co-authored-by: yiyixuxu <yixu310@gmail,com>
Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>