From 117301372bf5673dea4ffa8e3572c3558e6203e6 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 13 Dec 2023 09:20:28 +0000 Subject: [PATCH 01/15] first draft - seperate out ada_group and spatial norm --- src/diffusers/models/resnet.py | 203 ++++++++++++++++++++++++++++----- 1 file changed, 173 insertions(+), 30 deletions(-) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 970d2be05b7a..17648dc0e3a2 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -20,7 +20,7 @@ import torch.nn as nn import torch.nn.functional as F -from ..utils import USE_PEFT_BACKEND +from ..utils import USE_PEFT_BACKEND, deprecate from .activations import get_activation from .attention_processor import SpatialNorm from .lora import LoRACompatibleConv, LoRACompatibleLinear @@ -581,6 +581,156 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor: return F.conv_transpose2d(inputs, weight, stride=2, padding=self.pad * 2 + 1) +class ResnetBlockCondNorm2D(nn.Module): + r""" + A Resnet block. + + Parameters: + in_channels (`int`): The number of channels in the input. + out_channels (`int`, *optional*, default to be `None`): + The number of output channels for the first conv2d layer. If None, same as `in_channels`. + dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use. + temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding. + groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer. + groups_out (`int`, *optional*, default to None): + The number of groups to use for the second normalization layer. if set to None, same as `groups`. + eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization. + non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use. + time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config. + By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or + "ada_group" for a stronger conditioning with scale and shift. + kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see + [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`]. + output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output. + use_in_shortcut (`bool`, *optional*, default to `True`): + If `True`, add a 1x1 nn.conv2d layer for skip-connection. + up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer. + down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer. + conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the + `conv_shortcut` output. + conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output. + If None, same as `out_channels`. 
+ """ + + def __init__( + self, + *, + in_channels: int, + out_channels: Optional[int] = None, + conv_shortcut: bool = False, + dropout: float = 0.0, + temb_channels: int = 512, + groups: int = 32, + groups_out: Optional[int] = None, + non_linearity: str = "swish", + time_embedding_norm: str = "ada_group", # ada_group, spatial + output_scale_factor: float = 1.0, + use_in_shortcut: Optional[bool] = None, + up: bool = False, + down: bool = False, + conv_shortcut_bias: bool = True, + conv_2d_out_channels: Optional[int] = None, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + self.up = up + self.down = down + self.output_scale_factor = output_scale_factor + self.time_embedding_norm = time_embedding_norm + + conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv + + if groups_out is None: + groups_out = groups + + if self.time_embedding_norm == "ada_group": # ada_group + self.norm1 = AdaGroupNorm(in_channels, groups, temb_channels) + elif self.time_embedding_norm == "spatial": + self.norm1 = SpatialNorm(in_channels, temb_channels) + else: + raise ValueError(f" unsupported time_embedding_norm: {self.time_embedding_norm}") + + self.conv1 = conv_cls(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + + if self.time_embedding_norm == "ada_group": # ada_group + self.norm2 = AdaGroupNorm(out_channels, groups_out, temb_channels) + elif self.time_embedding_norm == "spatial": + self.norm2 = SpatialNorm(out_channels, temb_channels) + else: + raise ValueError(f" unsupported time_embedding_norm: {self.time_embedding_norm}") + + self.dropout = torch.nn.Dropout(dropout) + + conv_2d_out_channels = conv_2d_out_channels or out_channels + self.conv2 = conv_cls(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1) + + self.nonlinearity = get_activation(non_linearity) + + self.upsample = self.downsample = None + if self.up: + self.upsample = Upsample2D(in_channels, use_conv=False) + elif self.down: + self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op") + + self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut + + self.conv_shortcut = None + if self.use_in_shortcut: + self.conv_shortcut = conv_cls( + in_channels, + conv_2d_out_channels, + kernel_size=1, + stride=1, + padding=0, + bias=conv_shortcut_bias, + ) + + def forward( + self, + input_tensor: torch.FloatTensor, + temb: torch.FloatTensor, + scale: float = 1.0, + ) -> torch.FloatTensor: + hidden_states = input_tensor + + hidden_states = self.norm1(hidden_states, temb) + + hidden_states = self.nonlinearity(hidden_states) + + if self.upsample is not None: + # upsample_nearest_nhwc fails with large batch sizes. 
see https://github.com/huggingface/diffusers/issues/984 + if hidden_states.shape[0] >= 64: + input_tensor = input_tensor.contiguous() + hidden_states = hidden_states.contiguous() + input_tensor = self.upsample(input_tensor, scale=scale) + hidden_states = self.upsample(hidden_states, scale=scale) + + elif self.downsample is not None: + input_tensor = self.downsample(input_tensor, scale=scale) + hidden_states = self.downsample(hidden_states, scale=scale) + + hidden_states = self.conv1(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv1(hidden_states) + + hidden_states = self.norm2(hidden_states, temb) + + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv2(hidden_states) + + if self.conv_shortcut is not None: + input_tensor = ( + self.conv_shortcut(input_tensor, scale) if not USE_PEFT_BACKEND else self.conv_shortcut(input_tensor) + ) + + output_tensor = (input_tensor + hidden_states) / self.output_scale_factor + + return output_tensor + + class ResnetBlock2D(nn.Module): r""" A Resnet block. @@ -626,7 +776,7 @@ def __init__( eps: float = 1e-6, non_linearity: str = "swish", skip_time_act: bool = False, - time_embedding_norm: str = "default", # default, scale_shift, ada_group, spatial + time_embedding_norm: str = "default", # default, scale_shift, kernel: Optional[torch.FloatTensor] = None, output_scale_factor: float = 1.0, use_in_shortcut: Optional[bool] = None, @@ -636,7 +786,19 @@ def __init__( conv_2d_out_channels: Optional[int] = None, ): super().__init__() - self.pre_norm = pre_norm + if time_embedding_norm == "ada_group": + deprecate( + "ada_group", + "1.0.0", + "Passing `ada_group` as `time_embedding_norm` is deprecated, please create `ResnetBlockCondNorm2D` instead", + ) + if time_embedding_norm == "spatial": + deprecate( + "spatial", + "1.0.0", + "Passing `spatial` as `time_embedding_norm` is deprecated, please create `ResnetBlockCondNorm2D` instead", + ) + self.pre_norm = True self.in_channels = in_channels out_channels = in_channels if out_channels is None else out_channels @@ -654,12 +816,7 @@ def __init__( if groups_out is None: groups_out = groups - if self.time_embedding_norm == "ada_group": - self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps) - elif self.time_embedding_norm == "spatial": - self.norm1 = SpatialNorm(in_channels, temb_channels) - else: - self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) + self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) self.conv1 = conv_cls(in_channels, out_channels, kernel_size=3, stride=1, padding=1) @@ -668,19 +825,12 @@ def __init__( self.time_emb_proj = linear_cls(temb_channels, out_channels) elif self.time_embedding_norm == "scale_shift": self.time_emb_proj = linear_cls(temb_channels, 2 * out_channels) - elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial": - self.time_emb_proj = None else: raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") else: self.time_emb_proj = None - if self.time_embedding_norm == "ada_group": - self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps) - elif self.time_embedding_norm == "spatial": - self.norm2 = SpatialNorm(out_channels, temb_channels) - else: - self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) + self.norm2 = 
torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) self.dropout = torch.nn.Dropout(dropout) conv_2d_out_channels = conv_2d_out_channels or out_channels @@ -727,11 +877,7 @@ def forward( ) -> torch.FloatTensor: hidden_states = input_tensor - if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial": - hidden_states = self.norm1(hidden_states, temb) - else: - hidden_states = self.norm1(hidden_states) - + hidden_states = self.norm1(hidden_states) hidden_states = self.nonlinearity(hidden_states) if self.upsample is not None: @@ -772,16 +918,13 @@ def forward( else self.time_emb_proj(temb)[:, :, None, None] ) - if temb is not None and self.time_embedding_norm == "default": - hidden_states = hidden_states + temb - - if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial": - hidden_states = self.norm2(hidden_states, temb) - else: + if self.time_embedding_norm == "default": + if temb is not None: + hidden_states = hidden_states + temb hidden_states = self.norm2(hidden_states) - - if temb is not None and self.time_embedding_norm == "scale_shift": + elif self.time_embedding_norm == "scale_shift": scale, shift = torch.chunk(temb, 2, dim=1) + hidden_states = self.norm2(hidden_states) hidden_states = hidden_states * (1 + scale) + shift hidden_states = self.nonlinearity(hidden_states) From c79dc8ea03302d52b50098a4d08331ae4a365479 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 13 Dec 2023 09:54:45 +0000 Subject: [PATCH 02/15] add a test line print --- src/diffusers/models/resnet.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 17648dc0e3a2..7f75c345b252 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -938,6 +938,7 @@ def forward( ) output_tensor = (input_tensor + hidden_states) / self.output_scale_factor + print(f" output_tensor : {output_tensor.shape}, {output_tensor.abs().sum()}") return output_tensor From b75c8e6a600e6046fa4bc45333bb25e4246032e5 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 13 Dec 2023 16:39:09 +0000 Subject: [PATCH 03/15] more prints --- src/diffusers/models/resnet.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 7f75c345b252..d149514367de 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -876,9 +876,12 @@ def forward( scale: float = 1.0, ) -> torch.FloatTensor: hidden_states = input_tensor + print(f" ") + print(f" hidden_states: {hidden_states.shape}, {hidden_states.abs().sum()}") hidden_states = self.norm1(hidden_states) hidden_states = self.nonlinearity(hidden_states) + print(f" norm1 -> hidden_states: {hidden_states.shape}, {hidden_states.abs().sum()}") if self.upsample is not None: # upsample_nearest_nhwc fails with large batch sizes. 
see https://github.com/huggingface/diffusers/issues/984 @@ -906,6 +909,7 @@ def forward( if isinstance(self.downsample, Downsample2D) else self.downsample(hidden_states) ) + print(f" up/down -> hidden_states: {hidden_states.shape}, {hidden_states.abs().sum()}") hidden_states = self.conv1(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv1(hidden_states) @@ -917,7 +921,9 @@ def forward( if not USE_PEFT_BACKEND else self.time_emb_proj(temb)[:, :, None, None] ) - + print(f" temb: {temb.shape}, {temb.abs().sum()}") + + print(f" self.time_embedding_norm: {self.time_embedding_norm}") if self.time_embedding_norm == "default": if temb is not None: hidden_states = hidden_states + temb @@ -928,6 +934,7 @@ def forward( hidden_states = hidden_states * (1 + scale) + shift hidden_states = self.nonlinearity(hidden_states) + print(f" time_embedding_norm -> hidden_states: {hidden_states.shape}, {hidden_states.abs().sum()}") hidden_states = self.dropout(hidden_states) hidden_states = self.conv2(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv2(hidden_states) From b89788e7263008cf3fc9502b4b8d8d829f29dcf2 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 13 Dec 2023 23:31:08 +0000 Subject: [PATCH 04/15] fi --- src/diffusers/models/resnet.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index d149514367de..2addfcf1c8b4 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -929,9 +929,13 @@ def forward( hidden_states = hidden_states + temb hidden_states = self.norm2(hidden_states) elif self.time_embedding_norm == "scale_shift": + if temb is None: + raise ValueError(f" `temb` should not be None when `time_embedding_norm` is {self.time_embedding_norm}") scale, shift = torch.chunk(temb, 2, dim=1) hidden_states = self.norm2(hidden_states) hidden_states = hidden_states * (1 + scale) + shift + else: + hidden_states = self.norm2(hidden_states) hidden_states = self.nonlinearity(hidden_states) print(f" time_embedding_norm -> hidden_states: {hidden_states.shape}, {hidden_states.abs().sum()}") From ef8b90078b63c38c1acc36a3af00f9016a0da539 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 13 Dec 2023 23:46:45 +0000 Subject: [PATCH 05/15] remove testing lines --- src/diffusers/models/resnet.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 2addfcf1c8b4..b4357ef2ba17 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -876,12 +876,9 @@ def forward( scale: float = 1.0, ) -> torch.FloatTensor: hidden_states = input_tensor - print(f" ") - print(f" hidden_states: {hidden_states.shape}, {hidden_states.abs().sum()}") hidden_states = self.norm1(hidden_states) hidden_states = self.nonlinearity(hidden_states) - print(f" norm1 -> hidden_states: {hidden_states.shape}, {hidden_states.abs().sum()}") if self.upsample is not None: # upsample_nearest_nhwc fails with large batch sizes. 
see https://github.com/huggingface/diffusers/issues/984 @@ -909,7 +906,6 @@ def forward( if isinstance(self.downsample, Downsample2D) else self.downsample(hidden_states) ) - print(f" up/down -> hidden_states: {hidden_states.shape}, {hidden_states.abs().sum()}") hidden_states = self.conv1(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv1(hidden_states) @@ -921,16 +917,16 @@ def forward( if not USE_PEFT_BACKEND else self.time_emb_proj(temb)[:, :, None, None] ) - print(f" temb: {temb.shape}, {temb.abs().sum()}") - - print(f" self.time_embedding_norm: {self.time_embedding_norm}") + if self.time_embedding_norm == "default": if temb is not None: hidden_states = hidden_states + temb hidden_states = self.norm2(hidden_states) elif self.time_embedding_norm == "scale_shift": if temb is None: - raise ValueError(f" `temb` should not be None when `time_embedding_norm` is {self.time_embedding_norm}") + raise ValueError( + f" `temb` should not be None when `time_embedding_norm` is {self.time_embedding_norm}" + ) scale, shift = torch.chunk(temb, 2, dim=1) hidden_states = self.norm2(hidden_states) hidden_states = hidden_states * (1 + scale) + shift @@ -938,7 +934,6 @@ def forward( hidden_states = self.norm2(hidden_states) hidden_states = self.nonlinearity(hidden_states) - print(f" time_embedding_norm -> hidden_states: {hidden_states.shape}, {hidden_states.abs().sum()}") hidden_states = self.dropout(hidden_states) hidden_states = self.conv2(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv2(hidden_states) @@ -949,7 +944,6 @@ def forward( ) output_tensor = (input_tensor + hidden_states) / self.output_scale_factor - print(f" output_tensor : {output_tensor.shape}, {output_tensor.abs().sum()}") return output_tensor From 282bce4346d97161d65cf9ae209fd727a6f91405 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 14 Dec 2023 01:40:16 +0000 Subject: [PATCH 06/15] get_resnet_block() --- src/diffusers/models/unet_2d_blocks.py | 183 +++++++++++++++++++------ 1 file changed, 140 insertions(+), 43 deletions(-) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index e404cef224ff..c21abd6b4f55 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -24,7 +24,16 @@ from .attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0 from .dual_transformer_2d import DualTransformer2DModel from .normalization import AdaGroupNorm -from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D +from .resnet import ( + Downsample2D, + FirDownsample2D, + FirUpsample2D, + KDownsample2D, + KUpsample2D, + ResnetBlock2D, + ResnetBlockCondNorm2D, + Upsample2D, +) from .transformer_2d import Transformer2DModel @@ -465,6 +474,98 @@ def get_up_block( raise ValueError(f"{up_block_type} does not exist.") +def get_resnet_block( + in_channels: int, + out_channels: Optional[int] = None, + conv_shortcut: bool = False, + dropout: float = 0.0, + temb_channels: int = 512, + groups: int = 32, + groups_out: Optional[int] = None, + pre_norm: bool = True, + eps: float = 1e-6, + non_linearity: str = "swish", + skip_time_act: bool = False, + time_embedding_norm: str = "default", # default, scale_shift, ada_group, spatial + kernel: Optional[torch.FloatTensor] = None, + output_scale_factor: float = 1.0, + use_in_shortcut: Optional[bool] = None, + up: bool = False, + down: bool = False, + conv_shortcut_bias: bool = True, + conv_2d_out_channels: Optional[int] = None, +) 
-> nn.Module: + r""" + Get a Resnet block. + + Parameters: + in_channels (`int`): The number of channels in the input. + out_channels (`int`, *optional*, default to be `None`): + The number of output channels for the first conv2d layer. If None, same as `in_channels`. + dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use. + temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding. + groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer. + groups_out (`int`, *optional*, default to None): + The number of groups to use for the second normalization layer. if set to None, same as `groups`. + eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization. + non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use. + time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config. + By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or + "ada_group" for a stronger conditioning with scale and shift. + kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see + [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`]. + output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output. + use_in_shortcut (`bool`, *optional*, default to `True`): + If `True`, add a 1x1 nn.conv2d layer for skip-connection. + up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer. + down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer. + conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the + `conv_shortcut` output. + conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in + the output. If None, same as `out_channels`. + """ + + if time_embedding_norm == "ada_group" or time_embedding_norm == "spatial": + return ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=out_channels, + conv_shortcut=conv_shortcut, + dropout=dropout, + temb_channels=temb_channels, + groups=groups, + groups_out=groups_out, + non_linearity=non_linearity, + time_embedding_norm=time_embedding_norm, + output_scale_factor=output_scale_factor, + use_in_shortcut=use_in_shortcut, + up=up, + down=down, + conv_shortcut_bias=conv_shortcut_bias, + conv_2d_out_channels=conv_2d_out_channels, + ) + else: + return ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + conv_shortcut=conv_shortcut, + dropout=dropout, + temb_channels=temb_channels, + groups=groups, + groups_out=groups_out, + eps=eps, + non_linearity=non_linearity, + skip_time_act=skip_time_act, + time_embedding_norm=time_embedding_norm, + kernel=kernel, + output_scale_factor=output_scale_factor, + use_in_shortcut=use_in_shortcut, + up=up, + down=down, + conv_shortcut_bias=conv_shortcut_bias, + conv_2d_out_channels=conv_2d_out_channels, + ) + + class AutoencoderTinyBlock(nn.Module): """ Tiny Autoencoder block used in [`AutoencoderTiny`]. 
It is a mini residual module consisting of plain conv + ReLU @@ -558,7 +659,7 @@ def __init__( # there is always at least one resnet resnets = [ - ResnetBlock2D( + get_resnet_block( in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels, @@ -600,7 +701,7 @@ def __init__( attentions.append(None) resnets.append( - ResnetBlock2D( + get_resnet_block( in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels, @@ -660,7 +761,7 @@ def __init__( # there is always at least one resnet resnets = [ - ResnetBlock2D( + get_resnet_block( in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels, @@ -702,7 +803,7 @@ def __init__( ) ) resnets.append( - ResnetBlock2D( + get_resnet_block( in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels, @@ -803,7 +904,7 @@ def __init__( # there is always at least one resnet resnets = [ - ResnetBlock2D( + get_resnet_block( in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels, @@ -840,7 +941,7 @@ def __init__( ) ) resnets.append( - ResnetBlock2D( + get_resnet_block( in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels, @@ -929,7 +1030,7 @@ def __init__( for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels resnets.append( - ResnetBlock2D( + get_resnet_block( in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -971,7 +1072,7 @@ def __init__( elif downsample_type == "resnet": self.downsamplers = nn.ModuleList( [ - ResnetBlock2D( + get_resnet_block( in_channels=out_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -1057,7 +1158,7 @@ def __init__( for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels resnets.append( - ResnetBlock2D( + get_resnet_block( in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -1204,7 +1305,7 @@ def __init__( for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels resnets.append( - ResnetBlock2D( + get_resnet_block( in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -1291,7 +1392,7 @@ def __init__( for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels resnets.append( - ResnetBlock2D( + get_resnet_block( in_channels=in_channels, out_channels=out_channels, temb_channels=None, @@ -1359,7 +1460,7 @@ def __init__( for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels resnets.append( - ResnetBlock2D( + get_resnet_block( in_channels=in_channels, out_channels=out_channels, temb_channels=None, @@ -1443,7 +1544,7 @@ def __init__( for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels self.resnets.append( - ResnetBlock2D( + get_resnet_block( in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -1473,7 +1574,7 @@ def __init__( ) if add_downsample: - self.resnet_down = ResnetBlock2D( + self.resnet_down = get_resnet_block( in_channels=out_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -1544,7 +1645,7 @@ def __init__( for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels self.resnets.append( - ResnetBlock2D( + get_resnet_block( in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -1560,7 +1661,7 @@ def __init__( ) if add_downsample: - self.resnet_down = ResnetBlock2D( + self.resnet_down = get_resnet_block( in_channels=out_channels, 
out_channels=out_channels, temb_channels=temb_channels, @@ -1630,7 +1731,7 @@ def __init__( for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels resnets.append( - ResnetBlock2D( + get_resnet_block( in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -1650,7 +1751,7 @@ def __init__( if add_downsample: self.downsamplers = nn.ModuleList( [ - ResnetBlock2D( + get_resnet_block( in_channels=out_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -1741,7 +1842,7 @@ def __init__( for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels resnets.append( - ResnetBlock2D( + get_resnet_block( in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -1781,7 +1882,7 @@ def __init__( if add_downsample: self.downsamplers = nn.ModuleList( [ - ResnetBlock2D( + get_resnet_block( in_channels=out_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -1889,14 +1990,13 @@ def __init__( groups_out = out_channels // resnet_group_size resnets.append( - ResnetBlock2D( + ResnetBlockCondNorm2D( in_channels=in_channels, out_channels=out_channels, dropout=dropout, temb_channels=temb_channels, groups=groups, groups_out=groups_out, - eps=resnet_eps, non_linearity=resnet_act_fn, time_embedding_norm="ada_group", conv_shortcut_bias=False, @@ -1975,14 +2075,13 @@ def __init__( groups_out = out_channels // resnet_group_size resnets.append( - ResnetBlock2D( + ResnetBlockCondNorm2D( in_channels=in_channels, out_channels=out_channels, dropout=dropout, temb_channels=temb_channels, groups=groups, groups_out=groups_out, - eps=resnet_eps, non_linearity=resnet_act_fn, time_embedding_norm="ada_group", conv_shortcut_bias=False, @@ -2110,7 +2209,7 @@ def __init__( resnet_in_channels = prev_output_channel if i == 0 else out_channels resnets.append( - ResnetBlock2D( + get_resnet_block( in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -2146,7 +2245,7 @@ def __init__( elif upsample_type == "resnet": self.upsamplers = nn.ModuleList( [ - ResnetBlock2D( + get_resnet_block( in_channels=out_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -2235,7 +2334,7 @@ def __init__( resnet_in_channels = prev_output_channel if i == 0 else out_channels resnets.append( - ResnetBlock2D( + get_resnet_block( in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -2393,7 +2492,7 @@ def __init__( resnet_in_channels = prev_output_channel if i == 0 else out_channels resnets.append( - ResnetBlock2D( + get_resnet_block( in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -2501,7 +2600,7 @@ def __init__( input_channels = in_channels if i == 0 else out_channels resnets.append( - ResnetBlock2D( + get_resnet_block( in_channels=input_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -2569,7 +2668,7 @@ def __init__( input_channels = in_channels if i == 0 else out_channels resnets.append( - ResnetBlock2D( + get_resnet_block( in_channels=input_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -2650,7 +2749,7 @@ def __init__( resnet_in_channels = prev_output_channel if i == 0 else out_channels self.resnets.append( - ResnetBlock2D( + get_resnet_block( in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -2688,7 +2787,7 @@ def __init__( self.upsampler 
= FirUpsample2D(in_channels, out_channels=out_channels) if add_upsample: - self.resnet_up = ResnetBlock2D( + self.resnet_up = get_resnet_block( in_channels=out_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -2779,7 +2878,7 @@ def __init__( resnet_in_channels = prev_output_channel if i == 0 else out_channels self.resnets.append( - ResnetBlock2D( + get_resnet_block( in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -2796,7 +2895,7 @@ def __init__( self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels) if add_upsample: - self.resnet_up = ResnetBlock2D( + self.resnet_up = get_resnet_block( in_channels=out_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -2885,7 +2984,7 @@ def __init__( resnet_in_channels = prev_output_channel if i == 0 else out_channels resnets.append( - ResnetBlock2D( + get_resnet_block( in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -2905,7 +3004,7 @@ def __init__( if add_upsample: self.upsamplers = nn.ModuleList( [ - ResnetBlock2D( + get_resnet_block( in_channels=out_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -3004,7 +3103,7 @@ def __init__( resnet_in_channels = prev_output_channel if i == 0 else out_channels resnets.append( - ResnetBlock2D( + get_resnet_block( in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -3044,7 +3143,7 @@ def __init__( if add_upsample: self.upsamplers = nn.ModuleList( [ - ResnetBlock2D( + get_resnet_block( in_channels=out_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -3159,11 +3258,10 @@ def __init__( groups_out = out_channels // resnet_group_size resnets.append( - ResnetBlock2D( + ResnetBlockCondNorm2D( in_channels=in_channels, out_channels=k_out_channels if (i == num_layers - 1) else out_channels, temb_channels=temb_channels, - eps=resnet_eps, groups=groups, groups_out=groups_out, dropout=dropout, @@ -3267,12 +3365,11 @@ def __init__( conv_2d_out_channels = None resnets.append( - ResnetBlock2D( + ResnetBlockCondNorm2D( in_channels=in_channels, out_channels=out_channels, conv_2d_out_channels=conv_2d_out_channels, temb_channels=temb_channels, - eps=resnet_eps, groups=groups, groups_out=groups_out, dropout=dropout, From 3c7c361f4ee0bf7ef2cfcb6d01cd5c57d8738830 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 14 Dec 2023 07:42:21 +0000 Subject: [PATCH 07/15] fix ada_group norm --- src/diffusers/models/resnet.py | 5 +++-- src/diffusers/models/unet_2d_blocks.py | 5 +++++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index b4357ef2ba17..e99742390eed 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -622,6 +622,7 @@ def __init__( temb_channels: int = 512, groups: int = 32, groups_out: Optional[int] = None, + eps: float = 1e-6, non_linearity: str = "swish", time_embedding_norm: str = "ada_group", # ada_group, spatial output_scale_factor: float = 1.0, @@ -647,7 +648,7 @@ def __init__( groups_out = groups if self.time_embedding_norm == "ada_group": # ada_group - self.norm1 = AdaGroupNorm(in_channels, groups, temb_channels) + self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps) elif self.time_embedding_norm == "spatial": self.norm1 = SpatialNorm(in_channels, temb_channels) else: @@ -656,7 +657,7 @@ def __init__( self.conv1 = 
conv_cls(in_channels, out_channels, kernel_size=3, stride=1, padding=1) if self.time_embedding_norm == "ada_group": # ada_group - self.norm2 = AdaGroupNorm(out_channels, groups_out, temb_channels) + self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps) elif self.time_embedding_norm == "spatial": self.norm2 = SpatialNorm(out_channels, temb_channels) else: diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index c21abd6b4f55..5e503fd38024 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -534,6 +534,7 @@ def get_resnet_block( temb_channels=temb_channels, groups=groups, groups_out=groups_out, + eps=eps, non_linearity=non_linearity, time_embedding_norm=time_embedding_norm, output_scale_factor=output_scale_factor, @@ -1997,6 +1998,7 @@ def __init__( temb_channels=temb_channels, groups=groups, groups_out=groups_out, + eps=resnet_eps, non_linearity=resnet_act_fn, time_embedding_norm="ada_group", conv_shortcut_bias=False, @@ -2082,6 +2084,7 @@ def __init__( temb_channels=temb_channels, groups=groups, groups_out=groups_out, + eps=resnet_eps, non_linearity=resnet_act_fn, time_embedding_norm="ada_group", conv_shortcut_bias=False, @@ -3262,6 +3265,7 @@ def __init__( in_channels=in_channels, out_channels=k_out_channels if (i == num_layers - 1) else out_channels, temb_channels=temb_channels, + eps=resnet_eps, groups=groups, groups_out=groups_out, dropout=dropout, @@ -3370,6 +3374,7 @@ def __init__( out_channels=out_channels, conv_2d_out_channels=conv_2d_out_channels, temb_channels=temb_channels, + eps=resnet_eps, groups=groups, groups_out=groups_out, dropout=dropout, From d6e588175fda87e9ef7568956a99e8e336fa5995 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 14 Dec 2023 08:03:09 +0000 Subject: [PATCH 08/15] update doc string --- src/diffusers/models/resnet.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index e99742390eed..29ff9eaa2dd0 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -583,7 +583,7 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor: class ResnetBlockCondNorm2D(nn.Module): r""" - A Resnet block. + A Resnet block that use normalization layer that incorporate conditioning information. Parameters: in_channels (`int`): The number of channels in the input. @@ -596,9 +596,8 @@ class ResnetBlockCondNorm2D(nn.Module): The number of groups to use for the second normalization layer. if set to None, same as `groups`. eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization. non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use. - time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config. - By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or - "ada_group" for a stronger conditioning with scale and shift. + time_embedding_norm (`str`, *optional*, default to `"ada_group"` ): + The normalization layer for time embedding `temb`. Currently only support "ada_group" or "spatial". kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`]. output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output. 
@@ -658,7 +657,7 @@ def __init__( if self.time_embedding_norm == "ada_group": # ada_group self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps) - elif self.time_embedding_norm == "spatial": + elif self.time_embedding_norm == "spatial": # spatial self.norm2 = SpatialNorm(out_channels, temb_channels) else: raise ValueError(f" unsupported time_embedding_norm: {self.time_embedding_norm}") @@ -748,8 +747,8 @@ class ResnetBlock2D(nn.Module): eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization. non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use. time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config. - By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or - "ada_group" for a stronger conditioning with scale and shift. + By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" + for a stronger conditioning with scale and shift. kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`]. output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output. From 8e0c9c9075d30ee730ce74029769b9aee873ba04 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 28 Dec 2023 05:07:57 +0000 Subject: [PATCH 09/15] remove get_resnet_block function --- src/diffusers/models/unet_2d_blocks.py | 322 +++++++++++-------------- 1 file changed, 147 insertions(+), 175 deletions(-) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 5e503fd38024..d0a4d447fe6a 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -474,99 +474,6 @@ def get_up_block( raise ValueError(f"{up_block_type} does not exist.") -def get_resnet_block( - in_channels: int, - out_channels: Optional[int] = None, - conv_shortcut: bool = False, - dropout: float = 0.0, - temb_channels: int = 512, - groups: int = 32, - groups_out: Optional[int] = None, - pre_norm: bool = True, - eps: float = 1e-6, - non_linearity: str = "swish", - skip_time_act: bool = False, - time_embedding_norm: str = "default", # default, scale_shift, ada_group, spatial - kernel: Optional[torch.FloatTensor] = None, - output_scale_factor: float = 1.0, - use_in_shortcut: Optional[bool] = None, - up: bool = False, - down: bool = False, - conv_shortcut_bias: bool = True, - conv_2d_out_channels: Optional[int] = None, -) -> nn.Module: - r""" - Get a Resnet block. - - Parameters: - in_channels (`int`): The number of channels in the input. - out_channels (`int`, *optional*, default to be `None`): - The number of output channels for the first conv2d layer. If None, same as `in_channels`. - dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use. - temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding. - groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer. - groups_out (`int`, *optional*, default to None): - The number of groups to use for the second normalization layer. if set to None, same as `groups`. - eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization. - non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use. 
- time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config. - By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or - "ada_group" for a stronger conditioning with scale and shift. - kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see - [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`]. - output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output. - use_in_shortcut (`bool`, *optional*, default to `True`): - If `True`, add a 1x1 nn.conv2d layer for skip-connection. - up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer. - down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer. - conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the - `conv_shortcut` output. - conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in - the output. If None, same as `out_channels`. - """ - - if time_embedding_norm == "ada_group" or time_embedding_norm == "spatial": - return ResnetBlockCondNorm2D( - in_channels=in_channels, - out_channels=out_channels, - conv_shortcut=conv_shortcut, - dropout=dropout, - temb_channels=temb_channels, - groups=groups, - groups_out=groups_out, - eps=eps, - non_linearity=non_linearity, - time_embedding_norm=time_embedding_norm, - output_scale_factor=output_scale_factor, - use_in_shortcut=use_in_shortcut, - up=up, - down=down, - conv_shortcut_bias=conv_shortcut_bias, - conv_2d_out_channels=conv_2d_out_channels, - ) - else: - return ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - conv_shortcut=conv_shortcut, - dropout=dropout, - temb_channels=temb_channels, - groups=groups, - groups_out=groups_out, - eps=eps, - non_linearity=non_linearity, - skip_time_act=skip_time_act, - time_embedding_norm=time_embedding_norm, - kernel=kernel, - output_scale_factor=output_scale_factor, - use_in_shortcut=use_in_shortcut, - up=up, - down=down, - conv_shortcut_bias=conv_shortcut_bias, - conv_2d_out_channels=conv_2d_out_channels, - ) - - class AutoencoderTinyBlock(nn.Module): """ Tiny Autoencoder block used in [`AutoencoderTiny`]. 
It is a mini residual module consisting of plain conv + ReLU @@ -660,7 +567,7 @@ def __init__( # there is always at least one resnet resnets = [ - get_resnet_block( + ResnetBlock2D( in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels, @@ -702,7 +609,7 @@ def __init__( attentions.append(None) resnets.append( - get_resnet_block( + ResnetBlock2D( in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels, @@ -762,7 +669,7 @@ def __init__( # there is always at least one resnet resnets = [ - get_resnet_block( + ResnetBlock2D( in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels, @@ -804,7 +711,7 @@ def __init__( ) ) resnets.append( - get_resnet_block( + ResnetBlock2D( in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels, @@ -905,7 +812,7 @@ def __init__( # there is always at least one resnet resnets = [ - get_resnet_block( + ResnetBlock2D( in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels, @@ -942,7 +849,7 @@ def __init__( ) ) resnets.append( - get_resnet_block( + ResnetBlock2D( in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels, @@ -1031,7 +938,7 @@ def __init__( for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels resnets.append( - get_resnet_block( + ResnetBlock2D( in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -1073,7 +980,7 @@ def __init__( elif downsample_type == "resnet": self.downsamplers = nn.ModuleList( [ - get_resnet_block( + ResnetBlock2D( in_channels=out_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -1159,7 +1066,7 @@ def __init__( for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels resnets.append( - get_resnet_block( + ResnetBlock2D( in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -1306,7 +1213,7 @@ def __init__( for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels resnets.append( - get_resnet_block( + ResnetBlock2D( in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -1392,20 +1299,36 @@ def __init__( for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels - resnets.append( - get_resnet_block( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=None, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, + if resnet_time_scale_shift == "spatial": + resnets.append( + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm="spatial", + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + else: + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) ) - ) self.resnets = nn.ModuleList(resnets) @@ -1460,20 +1383,36 @@ def __init__( for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels - resnets.append( - get_resnet_block( - in_channels=in_channels, 
- out_channels=out_channels, - temb_channels=None, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, + if resnet_time_scale_shift == "spatial": + resnets.append( + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm="spatial", + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + else: + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) ) - ) attentions.append( Attention( out_channels, @@ -1545,7 +1484,7 @@ def __init__( for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels self.resnets.append( - get_resnet_block( + ResnetBlock2D( in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -1575,7 +1514,7 @@ def __init__( ) if add_downsample: - self.resnet_down = get_resnet_block( + self.resnet_down = ResnetBlock2D( in_channels=out_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -1646,7 +1585,7 @@ def __init__( for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels self.resnets.append( - get_resnet_block( + ResnetBlock2D( in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -1662,7 +1601,7 @@ def __init__( ) if add_downsample: - self.resnet_down = get_resnet_block( + self.resnet_down = ResnetBlock2D( in_channels=out_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -1732,7 +1671,7 @@ def __init__( for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels resnets.append( - get_resnet_block( + ResnetBlock2D( in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -1752,7 +1691,7 @@ def __init__( if add_downsample: self.downsamplers = nn.ModuleList( [ - get_resnet_block( + ResnetBlock2D( in_channels=out_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -1843,7 +1782,7 @@ def __init__( for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels resnets.append( - get_resnet_block( + ResnetBlock2D( in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -1883,7 +1822,7 @@ def __init__( if add_downsample: self.downsamplers = nn.ModuleList( [ - get_resnet_block( + ResnetBlock2D( in_channels=out_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -2212,7 +2151,7 @@ def __init__( resnet_in_channels = prev_output_channel if i == 0 else out_channels resnets.append( - get_resnet_block( + ResnetBlock2D( in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -2248,7 +2187,7 @@ def __init__( elif upsample_type == "resnet": self.upsamplers = nn.ModuleList( [ - get_resnet_block( + ResnetBlock2D( in_channels=out_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -2337,7 +2276,7 @@ def __init__( resnet_in_channels = prev_output_channel if i == 0 else out_channels resnets.append( - get_resnet_block( + ResnetBlock2D( 
in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -2495,7 +2434,7 @@ def __init__( resnet_in_channels = prev_output_channel if i == 0 else out_channels resnets.append( - get_resnet_block( + ResnetBlock2D( in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -2602,20 +2541,36 @@ def __init__( for i in range(num_layers): input_channels = in_channels if i == 0 else out_channels - resnets.append( - get_resnet_block( - in_channels=input_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, + if resnet_time_scale_shift == "spatial": + resnets.append( + ResnetBlockCondNorm2D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm="spatial", + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + else: + resnets.append( + ResnetBlock2D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) ) - ) self.resnets = nn.ModuleList(resnets) @@ -2670,20 +2625,37 @@ def __init__( for i in range(num_layers): input_channels = in_channels if i == 0 else out_channels - resnets.append( - get_resnet_block( - in_channels=input_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, + if resnet_time_scale_shift == "spatial": + resnets.append( + ResnetBlockCondNorm2D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm="spatial", + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) ) - ) + else: + resnets.append( + ResnetBlock2D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( Attention( out_channels, @@ -2752,7 +2724,7 @@ def __init__( resnet_in_channels = prev_output_channel if i == 0 else out_channels self.resnets.append( - get_resnet_block( + ResnetBlock2D( in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -2790,7 +2762,7 @@ def __init__( self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels) if add_upsample: - self.resnet_up = get_resnet_block( + self.resnet_up = ResnetBlock2D( in_channels=out_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -2881,7 +2853,7 @@ def __init__( resnet_in_channels = prev_output_channel if i == 0 else out_channels self.resnets.append( - get_resnet_block( + 
ResnetBlock2D( in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -2898,7 +2870,7 @@ def __init__( self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels) if add_upsample: - self.resnet_up = get_resnet_block( + self.resnet_up = ResnetBlock2D( in_channels=out_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -2987,7 +2959,7 @@ def __init__( resnet_in_channels = prev_output_channel if i == 0 else out_channels resnets.append( - get_resnet_block( + ResnetBlock2D( in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -3007,7 +2979,7 @@ def __init__( if add_upsample: self.upsamplers = nn.ModuleList( [ - get_resnet_block( + ResnetBlock2D( in_channels=out_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -3106,7 +3078,7 @@ def __init__( resnet_in_channels = prev_output_channel if i == 0 else out_channels resnets.append( - get_resnet_block( + ResnetBlock2D( in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -3146,7 +3118,7 @@ def __init__( if add_upsample: self.upsamplers = nn.ModuleList( [ - get_resnet_block( + ResnetBlock2D( in_channels=out_channels, out_channels=out_channels, temb_channels=temb_channels, From d5f2e822be3a3b537f9005d869c81c8d8e975706 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 28 Dec 2023 05:38:58 +0000 Subject: [PATCH 10/15] fix --- src/diffusers/models/unet_2d_blocks.py | 88 +++++++++++++++++--------- 1 file changed, 57 insertions(+), 31 deletions(-) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index d0a4d447fe6a..470a021165ac 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -566,20 +566,35 @@ def __init__( attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None # there is always at least one resnet - resnets = [ - ResnetBlock2D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ] + if resnet_time_scale_shift == "spatial": + resnets = [ + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm="spatial", + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ] + else: + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] attentions = [] if attention_head_dim is None: @@ -608,20 +623,35 @@ def __init__( else: attentions.append(None) - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, + if resnet_time_scale_shift == "spatial": + resnets.append( + ResnetBlockCondNorm2D( + 
in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm="spatial", + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ) + else: + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) ) - ) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) @@ -1311,7 +1341,6 @@ def __init__( time_embedding_norm="spatial", non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, ) ) else: @@ -1395,7 +1424,6 @@ def __init__( time_embedding_norm="spatial", non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, ) ) else: @@ -2553,7 +2581,6 @@ def __init__( time_embedding_norm="spatial", non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, ) ) else: @@ -2637,7 +2664,6 @@ def __init__( time_embedding_norm="spatial", non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, ) ) else: From 763ac1de70054997ce503c3a6da02d20659278d7 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 28 Dec 2023 05:45:54 +0000 Subject: [PATCH 11/15] Fix copies --- .../versatile_diffusion/modeling_text_unet.py | 85 +++++++++++++------ 1 file changed, 58 insertions(+), 27 deletions(-) diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index 8ac63636df86..3013bd890fa1 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -31,6 +31,7 @@ TimestepEmbedding, Timesteps, ) +from ...models.resnet import ResnetBlockCondNorm2D from ...models.transformer_2d import Transformer2DModel from ...models.unet_2d_condition import UNet2DConditionOutput from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers @@ -2117,20 +2118,35 @@ def __init__( attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None # there is always at least one resnet - resnets = [ - ResnetBlockFlat( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ] + if resnet_time_scale_shift == "spatial": + resnets = [ + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm="spatial", + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ] + else: + resnets = [ + ResnetBlockFlat( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] attentions = [] if attention_head_dim is 
         if attention_head_dim is None:
@@ -2159,20 +2175,35 @@ def __init__(
             else:
                 attentions.append(None)
 
-            resnets.append(
-                ResnetBlockFlat(
-                    in_channels=in_channels,
-                    out_channels=in_channels,
-                    temb_channels=temb_channels,
-                    eps=resnet_eps,
-                    groups=resnet_groups,
-                    dropout=dropout,
-                    time_embedding_norm=resnet_time_scale_shift,
-                    non_linearity=resnet_act_fn,
-                    output_scale_factor=output_scale_factor,
-                    pre_norm=resnet_pre_norm,
+            if resnet_time_scale_shift == "spatial":
+                resnets.append(
+                    ResnetBlockCondNorm2D(
+                        in_channels=in_channels,
+                        out_channels=in_channels,
+                        temb_channels=temb_channels,
+                        eps=resnet_eps,
+                        groups=resnet_groups,
+                        dropout=dropout,
+                        time_embedding_norm="spatial",
+                        non_linearity=resnet_act_fn,
+                        output_scale_factor=output_scale_factor,
+                    )
+                )
+            else:
+                resnets.append(
+                    ResnetBlockFlat(
+                        in_channels=in_channels,
+                        out_channels=in_channels,
+                        temb_channels=temb_channels,
+                        eps=resnet_eps,
+                        groups=resnet_groups,
+                        dropout=dropout,
+                        time_embedding_norm=resnet_time_scale_shift,
+                        non_linearity=resnet_act_fn,
+                        output_scale_factor=output_scale_factor,
+                        pre_norm=resnet_pre_norm,
+                    )
                 )
-            )
 
         self.attentions = nn.ModuleList(attentions)
         self.resnets = nn.ModuleList(resnets)

From 34ddd7e985ef33995036d29ce2498b6d4dd3444f Mon Sep 17 00:00:00 2001
From: yiyixuxu
Date: Thu, 28 Dec 2023 05:51:00 +0000
Subject: [PATCH 12/15] style

---
 .../deprecated/versatile_diffusion/modeling_text_unet.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py
index f6b1460b163f..c6c7826075d4 100644
--- a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py
+++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py
@@ -31,13 +31,13 @@
     TimestepEmbedding,
     Timesteps,
 )
-
 from ....models.resnet import ResnetBlockCondNorm2D
 from ....models.transformer_2d import Transformer2DModel
 from ....models.unet_2d_condition import UNet2DConditionOutput
 from ....utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
 from ....utils.torch_utils import apply_freeu
 
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

From 20196a21056033743ab6cb7ff4fd85bffa400d60 Mon Sep 17 00:00:00 2001
From: YiYi Xu
Date: Tue, 2 Jan 2024 11:36:22 -1000
Subject: [PATCH 13/15] Update src/diffusers/models/resnet.py

Co-authored-by: Patrick von Platen
---
 src/diffusers/models/resnet.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py
index 1ba6adb6b58c..7727799a51cc 100644
--- a/src/diffusers/models/resnet.py
+++ b/src/diffusers/models/resnet.py
@@ -254,7 +254,7 @@ def __init__(
                 "Passing `ada_group` as `time_embedding_norm` is deprecated, please create `ResnetBlockCondNorm2D` instead",
             )
         if time_embedding_norm == "spatial":
-            deprecate(
+            raise ValueError(
                 "spatial",
                 "1.0.0",
                 "Passing `spatial` as `time_embedding_norm` is deprecated, please create `ResnetBlockCondNorm2D` instead",

From d69b577ccc6acdf61e119c5193e737c4e8cbe17b Mon Sep 17 00:00:00 2001
From: yiyixuxu
Date: Wed, 3 Jan 2024 17:07:59 +0000
Subject: [PATCH 14/15] deprecate -> valueerror

---
 src/diffusers/models/resnet.py | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py
index 7727799a51cc..2d4c91021b6a 100644
--- a/src/diffusers/models/resnet.py
+++ b/src/diffusers/models/resnet.py
@@ -20,7 +20,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-from ..utils import USE_PEFT_BACKEND, deprecate
+from ..utils import USE_PEFT_BACKEND
 from .activations import get_activation
 from .attention_processor import SpatialNorm
 from .downsampling import (  # noqa
@@ -248,15 +248,11 @@ def __init__(
     ):
         super().__init__()
         if time_embedding_norm == "ada_group":
-            deprecate(
-                "ada_group",
-                "1.0.0",
+            raise ValueError(
                 "Passing `ada_group` as `time_embedding_norm` is deprecated, please create `ResnetBlockCondNorm2D` instead",
             )
         if time_embedding_norm == "spatial":
             raise ValueError(
-                "spatial",
-                "1.0.0",
                 "Passing `spatial` as `time_embedding_norm` is deprecated, please create `ResnetBlockCondNorm2D` instead",
             )
 

From 7f4f1698ea47d4491e217e5ba82d54ce4a6663f4 Mon Sep 17 00:00:00 2001
From: yiyixuxu
Date: Mon, 8 Jan 2024 10:05:55 +0000
Subject: [PATCH 15/15] update message

---
 src/diffusers/models/resnet.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py
index 2d4c91021b6a..3b8718f39940 100644
--- a/src/diffusers/models/resnet.py
+++ b/src/diffusers/models/resnet.py
@@ -249,11 +249,11 @@ def __init__(
         super().__init__()
         if time_embedding_norm == "ada_group":
             raise ValueError(
-                "Passing `ada_group` as `time_embedding_norm` is deprecated, please create `ResnetBlockCondNorm2D` instead",
+                "This class cannot be used with `time_embedding_norm==ada_group`, please use `ResnetBlockCondNorm2D` instead",
             )
         if time_embedding_norm == "spatial":
             raise ValueError(
-                "Passing `spatial` as `time_embedding_norm` is deprecated, please create `ResnetBlockCondNorm2D` instead",
+                "This class cannot be used with `time_embedding_norm==spatial`, please use `ResnetBlockCondNorm2D` instead",
             )
 
         self.pre_norm = True
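
A minimal sketch of how the split introduced in these patches is expected to look
from the caller's side, assuming a diffusers build that contains them; the channel
sizes and tensor shapes below are illustrative only and do not come from the diff:

import torch
from diffusers.models.resnet import ResnetBlock2D, ResnetBlockCondNorm2D

# made-up (batch, channels, height, width) input
x = torch.randn(1, 64, 32, 32)

# "spatial" (and "ada_group") conditioning now lives in ResnetBlockCondNorm2D.
# For spatial norm the conditioning input is a spatial feature map of shape
# (batch, temb_channels, h, w), not a 1D timestep embedding.
spatial_block = ResnetBlockCondNorm2D(
    in_channels=64,
    out_channels=64,
    temb_channels=512,
    time_embedding_norm="spatial",
)
zq = torch.randn(1, 512, 32, 32)
out = spatial_block(x, zq)

# ResnetBlock2D keeps the "default"/"scale_shift" timestep-embedding paths and
# takes a (batch, temb_channels) embedding as before.
plain_block = ResnetBlock2D(in_channels=64, out_channels=64, temb_channels=512)
temb = torch.randn(1, 512)
out = plain_block(x, temb)

# After the last two patches the conditional-norm modes are rejected outright
# with a ValueError instead of going through `deprecate`.
try:
    ResnetBlock2D(in_channels=64, out_channels=64, time_embedding_norm="spatial")
except ValueError as err:
    print(err)

This mirrors the dispatch added to the UNet blocks above: callers branch on
resnet_time_scale_shift == "spatial", construct ResnetBlockCondNorm2D (without
pre_norm) in that case, and fall back to ResnetBlock2D (or ResnetBlockFlat)
otherwise.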