-
Notifications
You must be signed in to change notification settings - Fork 6.6k
[Refactor] splitingResnetBlock2D into multiple blocks
#6166
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Changes from 14 commits
Commits
Show all changes
18 commits
Select commit
Hold shift + click to select a range
1173013
first draft - seperate out ada_group and spatial norm
c79dc8e
add a test line print
b75c8e6
more prints
b89788e
fi
ef8b900
remove testing lines
282bce4
get_resnet_block()
3c7c361
fix ada_group norm
d6e5881
update doc string
8e0c9c9
remove get_resnet_block function
d5f2e82
fix
763ac1d
Fix copies
e69a56d
merge main
34ddd7e
style
20196a2
Update src/diffusers/models/resnet.py
yiyixuxu d69b577
deprecate -> valueerror
b0af51e
Merge branch 'main' into resnet2d
sayakpaul 7f4f169
update message
10f7098
Merge branch 'resnet2d' of github.com:huggingface/diffusers into resn…
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,7 +20,7 @@ | |
| import torch.nn as nn | ||
| import torch.nn.functional as F | ||
|
|
||
| from ..utils import USE_PEFT_BACKEND | ||
| from ..utils import USE_PEFT_BACKEND, deprecate | ||
| from .activations import get_activation | ||
| from .attention_processor import SpatialNorm | ||
| from .downsampling import ( # noqa | ||
|
|
@@ -42,6 +42,156 @@ | |
| ) | ||
|
|
||
|
|
||
| class ResnetBlockCondNorm2D(nn.Module): | ||
| r""" | ||
| A Resnet block that use normalization layer that incorporate conditioning information. | ||
|
|
||
| Parameters: | ||
| in_channels (`int`): The number of channels in the input. | ||
| out_channels (`int`, *optional*, default to be `None`): | ||
| The number of output channels for the first conv2d layer. If None, same as `in_channels`. | ||
| dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use. | ||
| temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding. | ||
| groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer. | ||
| groups_out (`int`, *optional*, default to None): | ||
| The number of groups to use for the second normalization layer. if set to None, same as `groups`. | ||
| eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization. | ||
| non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use. | ||
| time_embedding_norm (`str`, *optional*, default to `"ada_group"` ): | ||
| The normalization layer for time embedding `temb`. Currently only support "ada_group" or "spatial". | ||
| kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see | ||
| [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`]. | ||
| output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output. | ||
| use_in_shortcut (`bool`, *optional*, default to `True`): | ||
| If `True`, add a 1x1 nn.conv2d layer for skip-connection. | ||
| up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer. | ||
| down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer. | ||
| conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the | ||
| `conv_shortcut` output. | ||
| conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output. | ||
| If None, same as `out_channels`. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| *, | ||
| in_channels: int, | ||
| out_channels: Optional[int] = None, | ||
| conv_shortcut: bool = False, | ||
| dropout: float = 0.0, | ||
| temb_channels: int = 512, | ||
| groups: int = 32, | ||
| groups_out: Optional[int] = None, | ||
| eps: float = 1e-6, | ||
| non_linearity: str = "swish", | ||
| time_embedding_norm: str = "ada_group", # ada_group, spatial | ||
| output_scale_factor: float = 1.0, | ||
| use_in_shortcut: Optional[bool] = None, | ||
| up: bool = False, | ||
| down: bool = False, | ||
| conv_shortcut_bias: bool = True, | ||
| conv_2d_out_channels: Optional[int] = None, | ||
| ): | ||
| super().__init__() | ||
| self.in_channels = in_channels | ||
| out_channels = in_channels if out_channels is None else out_channels | ||
| self.out_channels = out_channels | ||
| self.use_conv_shortcut = conv_shortcut | ||
| self.up = up | ||
| self.down = down | ||
| self.output_scale_factor = output_scale_factor | ||
| self.time_embedding_norm = time_embedding_norm | ||
|
|
||
| conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv | ||
|
|
||
| if groups_out is None: | ||
| groups_out = groups | ||
|
|
||
| if self.time_embedding_norm == "ada_group": # ada_group | ||
| self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps) | ||
| elif self.time_embedding_norm == "spatial": | ||
| self.norm1 = SpatialNorm(in_channels, temb_channels) | ||
| else: | ||
| raise ValueError(f" unsupported time_embedding_norm: {self.time_embedding_norm}") | ||
|
|
||
| self.conv1 = conv_cls(in_channels, out_channels, kernel_size=3, stride=1, padding=1) | ||
|
|
||
| if self.time_embedding_norm == "ada_group": # ada_group | ||
| self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps) | ||
| elif self.time_embedding_norm == "spatial": # spatial | ||
| self.norm2 = SpatialNorm(out_channels, temb_channels) | ||
| else: | ||
| raise ValueError(f" unsupported time_embedding_norm: {self.time_embedding_norm}") | ||
|
|
||
| self.dropout = torch.nn.Dropout(dropout) | ||
|
|
||
| conv_2d_out_channels = conv_2d_out_channels or out_channels | ||
| self.conv2 = conv_cls(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1) | ||
|
|
||
| self.nonlinearity = get_activation(non_linearity) | ||
|
|
||
| self.upsample = self.downsample = None | ||
| if self.up: | ||
| self.upsample = Upsample2D(in_channels, use_conv=False) | ||
| elif self.down: | ||
| self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op") | ||
|
|
||
| self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut | ||
|
|
||
| self.conv_shortcut = None | ||
| if self.use_in_shortcut: | ||
| self.conv_shortcut = conv_cls( | ||
| in_channels, | ||
| conv_2d_out_channels, | ||
| kernel_size=1, | ||
| stride=1, | ||
| padding=0, | ||
| bias=conv_shortcut_bias, | ||
| ) | ||
|
|
||
| def forward( | ||
| self, | ||
| input_tensor: torch.FloatTensor, | ||
| temb: torch.FloatTensor, | ||
| scale: float = 1.0, | ||
| ) -> torch.FloatTensor: | ||
| hidden_states = input_tensor | ||
|
|
||
| hidden_states = self.norm1(hidden_states, temb) | ||
|
|
||
| hidden_states = self.nonlinearity(hidden_states) | ||
|
|
||
| if self.upsample is not None: | ||
| # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 | ||
| if hidden_states.shape[0] >= 64: | ||
| input_tensor = input_tensor.contiguous() | ||
| hidden_states = hidden_states.contiguous() | ||
| input_tensor = self.upsample(input_tensor, scale=scale) | ||
| hidden_states = self.upsample(hidden_states, scale=scale) | ||
|
|
||
| elif self.downsample is not None: | ||
| input_tensor = self.downsample(input_tensor, scale=scale) | ||
| hidden_states = self.downsample(hidden_states, scale=scale) | ||
|
|
||
| hidden_states = self.conv1(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv1(hidden_states) | ||
|
|
||
| hidden_states = self.norm2(hidden_states, temb) | ||
|
|
||
| hidden_states = self.nonlinearity(hidden_states) | ||
|
|
||
| hidden_states = self.dropout(hidden_states) | ||
| hidden_states = self.conv2(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv2(hidden_states) | ||
|
|
||
| if self.conv_shortcut is not None: | ||
| input_tensor = ( | ||
| self.conv_shortcut(input_tensor, scale) if not USE_PEFT_BACKEND else self.conv_shortcut(input_tensor) | ||
| ) | ||
|
|
||
| output_tensor = (input_tensor + hidden_states) / self.output_scale_factor | ||
|
|
||
| return output_tensor | ||
|
|
||
|
|
||
| class ResnetBlock2D(nn.Module): | ||
| r""" | ||
| A Resnet block. | ||
|
|
@@ -58,8 +208,8 @@ class ResnetBlock2D(nn.Module): | |
| eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization. | ||
| non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use. | ||
| time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config. | ||
| By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or | ||
| "ada_group" for a stronger conditioning with scale and shift. | ||
| By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" | ||
| for a stronger conditioning with scale and shift. | ||
| kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see | ||
| [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`]. | ||
| output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output. | ||
|
|
@@ -87,7 +237,7 @@ def __init__( | |
| eps: float = 1e-6, | ||
| non_linearity: str = "swish", | ||
| skip_time_act: bool = False, | ||
| time_embedding_norm: str = "default", # default, scale_shift, ada_group, spatial | ||
| time_embedding_norm: str = "default", # default, scale_shift, | ||
| kernel: Optional[torch.FloatTensor] = None, | ||
| output_scale_factor: float = 1.0, | ||
| use_in_shortcut: Optional[bool] = None, | ||
|
|
@@ -97,7 +247,19 @@ def __init__( | |
| conv_2d_out_channels: Optional[int] = None, | ||
| ): | ||
| super().__init__() | ||
| self.pre_norm = pre_norm | ||
| if time_embedding_norm == "ada_group": | ||
| deprecate( | ||
| "ada_group", | ||
| "1.0.0", | ||
| "Passing `ada_group` as `time_embedding_norm` is deprecated, please create `ResnetBlockCondNorm2D` instead", | ||
| ) | ||
| if time_embedding_norm == "spatial": | ||
| raise ValueError( | ||
| "spatial", | ||
| "1.0.0", | ||
| "Passing `spatial` as `time_embedding_norm` is deprecated, please create `ResnetBlockCondNorm2D` instead", | ||
| ) | ||
|
|
||
|
||
| self.pre_norm = True | ||
| self.in_channels = in_channels | ||
| out_channels = in_channels if out_channels is None else out_channels | ||
|
|
@@ -115,12 +277,7 @@ def __init__( | |
| if groups_out is None: | ||
| groups_out = groups | ||
|
|
||
| if self.time_embedding_norm == "ada_group": | ||
| self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps) | ||
| elif self.time_embedding_norm == "spatial": | ||
| self.norm1 = SpatialNorm(in_channels, temb_channels) | ||
| else: | ||
| self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) | ||
| self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) | ||
|
|
||
| self.conv1 = conv_cls(in_channels, out_channels, kernel_size=3, stride=1, padding=1) | ||
|
|
||
|
|
@@ -129,19 +286,12 @@ def __init__( | |
| self.time_emb_proj = linear_cls(temb_channels, out_channels) | ||
| elif self.time_embedding_norm == "scale_shift": | ||
| self.time_emb_proj = linear_cls(temb_channels, 2 * out_channels) | ||
| elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial": | ||
| self.time_emb_proj = None | ||
| else: | ||
| raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") | ||
| else: | ||
| self.time_emb_proj = None | ||
|
|
||
| if self.time_embedding_norm == "ada_group": | ||
| self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps) | ||
| elif self.time_embedding_norm == "spatial": | ||
| self.norm2 = SpatialNorm(out_channels, temb_channels) | ||
| else: | ||
| self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) | ||
| self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) | ||
|
|
||
| self.dropout = torch.nn.Dropout(dropout) | ||
| conv_2d_out_channels = conv_2d_out_channels or out_channels | ||
|
|
@@ -188,11 +338,7 @@ def forward( | |
| ) -> torch.FloatTensor: | ||
| hidden_states = input_tensor | ||
|
|
||
| if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial": | ||
| hidden_states = self.norm1(hidden_states, temb) | ||
| else: | ||
| hidden_states = self.norm1(hidden_states) | ||
|
|
||
| hidden_states = self.norm1(hidden_states) | ||
| hidden_states = self.nonlinearity(hidden_states) | ||
|
|
||
| if self.upsample is not None: | ||
|
|
@@ -233,17 +379,20 @@ def forward( | |
| else self.time_emb_proj(temb)[:, :, None, None] | ||
| ) | ||
|
|
||
| if temb is not None and self.time_embedding_norm == "default": | ||
| hidden_states = hidden_states + temb | ||
|
|
||
| if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial": | ||
| hidden_states = self.norm2(hidden_states, temb) | ||
| else: | ||
| if self.time_embedding_norm == "default": | ||
| if temb is not None: | ||
| hidden_states = hidden_states + temb | ||
| hidden_states = self.norm2(hidden_states) | ||
|
|
||
| if temb is not None and self.time_embedding_norm == "scale_shift": | ||
| elif self.time_embedding_norm == "scale_shift": | ||
| if temb is None: | ||
| raise ValueError( | ||
| f" `temb` should not be None when `time_embedding_norm` is {self.time_embedding_norm}" | ||
| ) | ||
| scale, shift = torch.chunk(temb, 2, dim=1) | ||
| hidden_states = self.norm2(hidden_states) | ||
| hidden_states = hidden_states * (1 + scale) + shift | ||
| else: | ||
| hidden_states = self.norm2(hidden_states) | ||
|
|
||
| hidden_states = self.nonlinearity(hidden_states) | ||
|
|
||
|
|
||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am not sure if saying something is deprecated is a good idea in a "ValueError". Have we ever done that?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, let's indeed rename the message here to something like "This class cannot be used with "type==ada_group", please use XXX instead"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
updated!