Skip to content
Merged
213 changes: 181 additions & 32 deletions src/diffusers/models/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import torch.nn as nn
import torch.nn.functional as F

from ..utils import USE_PEFT_BACKEND
from ..utils import USE_PEFT_BACKEND, deprecate
from .activations import get_activation
from .attention_processor import SpatialNorm
from .downsampling import ( # noqa
Expand All @@ -42,6 +42,156 @@
)


class ResnetBlockCondNorm2D(nn.Module):
r"""
A Resnet block that use normalization layer that incorporate conditioning information.

Parameters:
in_channels (`int`): The number of channels in the input.
out_channels (`int`, *optional*, default to be `None`):
The number of output channels for the first conv2d layer. If None, same as `in_channels`.
dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
groups_out (`int`, *optional*, default to None):
The number of groups to use for the second normalization layer. if set to None, same as `groups`.
eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
time_embedding_norm (`str`, *optional*, default to `"ada_group"` ):
The normalization layer for time embedding `temb`. Currently only support "ada_group" or "spatial".
kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see
[`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
use_in_shortcut (`bool`, *optional*, default to `True`):
If `True`, add a 1x1 nn.conv2d layer for skip-connection.
up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer.
down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer.
conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the
`conv_shortcut` output.
conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output.
If None, same as `out_channels`.
"""

def __init__(
self,
*,
in_channels: int,
out_channels: Optional[int] = None,
conv_shortcut: bool = False,
dropout: float = 0.0,
temb_channels: int = 512,
groups: int = 32,
groups_out: Optional[int] = None,
eps: float = 1e-6,
non_linearity: str = "swish",
time_embedding_norm: str = "ada_group", # ada_group, spatial
output_scale_factor: float = 1.0,
use_in_shortcut: Optional[bool] = None,
up: bool = False,
down: bool = False,
conv_shortcut_bias: bool = True,
conv_2d_out_channels: Optional[int] = None,
):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.up = up
self.down = down
self.output_scale_factor = output_scale_factor
self.time_embedding_norm = time_embedding_norm

conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv

if groups_out is None:
groups_out = groups

if self.time_embedding_norm == "ada_group": # ada_group
self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
elif self.time_embedding_norm == "spatial":
self.norm1 = SpatialNorm(in_channels, temb_channels)
else:
raise ValueError(f" unsupported time_embedding_norm: {self.time_embedding_norm}")

self.conv1 = conv_cls(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

if self.time_embedding_norm == "ada_group": # ada_group
self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
elif self.time_embedding_norm == "spatial": # spatial
self.norm2 = SpatialNorm(out_channels, temb_channels)
else:
raise ValueError(f" unsupported time_embedding_norm: {self.time_embedding_norm}")

self.dropout = torch.nn.Dropout(dropout)

conv_2d_out_channels = conv_2d_out_channels or out_channels
self.conv2 = conv_cls(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)

self.nonlinearity = get_activation(non_linearity)

self.upsample = self.downsample = None
if self.up:
self.upsample = Upsample2D(in_channels, use_conv=False)
elif self.down:
self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")

self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut

self.conv_shortcut = None
if self.use_in_shortcut:
self.conv_shortcut = conv_cls(
in_channels,
conv_2d_out_channels,
kernel_size=1,
stride=1,
padding=0,
bias=conv_shortcut_bias,
)

def forward(
self,
input_tensor: torch.FloatTensor,
temb: torch.FloatTensor,
scale: float = 1.0,
) -> torch.FloatTensor:
hidden_states = input_tensor

hidden_states = self.norm1(hidden_states, temb)

hidden_states = self.nonlinearity(hidden_states)

if self.upsample is not None:
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
if hidden_states.shape[0] >= 64:
input_tensor = input_tensor.contiguous()
hidden_states = hidden_states.contiguous()
input_tensor = self.upsample(input_tensor, scale=scale)
hidden_states = self.upsample(hidden_states, scale=scale)

elif self.downsample is not None:
input_tensor = self.downsample(input_tensor, scale=scale)
hidden_states = self.downsample(hidden_states, scale=scale)

hidden_states = self.conv1(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv1(hidden_states)

hidden_states = self.norm2(hidden_states, temb)

hidden_states = self.nonlinearity(hidden_states)

hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv2(hidden_states)

if self.conv_shortcut is not None:
input_tensor = (
self.conv_shortcut(input_tensor, scale) if not USE_PEFT_BACKEND else self.conv_shortcut(input_tensor)
)

output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

return output_tensor


class ResnetBlock2D(nn.Module):
r"""
A Resnet block.
Expand All @@ -58,8 +208,8 @@ class ResnetBlock2D(nn.Module):
eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config.
By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or
"ada_group" for a stronger conditioning with scale and shift.
By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift"
for a stronger conditioning with scale and shift.
kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see
[`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
Expand Down Expand Up @@ -87,7 +237,7 @@ def __init__(
eps: float = 1e-6,
non_linearity: str = "swish",
skip_time_act: bool = False,
time_embedding_norm: str = "default", # default, scale_shift, ada_group, spatial
time_embedding_norm: str = "default", # default, scale_shift,
kernel: Optional[torch.FloatTensor] = None,
output_scale_factor: float = 1.0,
use_in_shortcut: Optional[bool] = None,
Expand All @@ -97,7 +247,19 @@ def __init__(
conv_2d_out_channels: Optional[int] = None,
):
super().__init__()
self.pre_norm = pre_norm
if time_embedding_norm == "ada_group":
deprecate(
"ada_group",
"1.0.0",
"Passing `ada_group` as `time_embedding_norm` is deprecated, please create `ResnetBlockCondNorm2D` instead",
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure if saying something is deprecated is a good idea in a "ValueError". Have we ever done that?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, let's indeed rename the message here to something like "This class cannot be used with "type==ada_group", please use XXX instead"

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated!

if time_embedding_norm == "spatial":
raise ValueError(
"spatial",
"1.0.0",
"Passing `spatial` as `time_embedding_norm` is deprecated, please create `ResnetBlockCondNorm2D` instead",
)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why use deprecate for one and ValueError for the other?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oops 😅 fixed it

self.pre_norm = True
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
Expand All @@ -115,12 +277,7 @@ def __init__(
if groups_out is None:
groups_out = groups

if self.time_embedding_norm == "ada_group":
self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
elif self.time_embedding_norm == "spatial":
self.norm1 = SpatialNorm(in_channels, temb_channels)
else:
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)

self.conv1 = conv_cls(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

Expand All @@ -129,19 +286,12 @@ def __init__(
self.time_emb_proj = linear_cls(temb_channels, out_channels)
elif self.time_embedding_norm == "scale_shift":
self.time_emb_proj = linear_cls(temb_channels, 2 * out_channels)
elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
self.time_emb_proj = None
else:
raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
else:
self.time_emb_proj = None

if self.time_embedding_norm == "ada_group":
self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
elif self.time_embedding_norm == "spatial":
self.norm2 = SpatialNorm(out_channels, temb_channels)
else:
self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)

self.dropout = torch.nn.Dropout(dropout)
conv_2d_out_channels = conv_2d_out_channels or out_channels
Expand Down Expand Up @@ -188,11 +338,7 @@ def forward(
) -> torch.FloatTensor:
hidden_states = input_tensor

if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
hidden_states = self.norm1(hidden_states, temb)
else:
hidden_states = self.norm1(hidden_states)

hidden_states = self.norm1(hidden_states)
hidden_states = self.nonlinearity(hidden_states)

if self.upsample is not None:
Expand Down Expand Up @@ -233,17 +379,20 @@ def forward(
else self.time_emb_proj(temb)[:, :, None, None]
)

if temb is not None and self.time_embedding_norm == "default":
hidden_states = hidden_states + temb

if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
hidden_states = self.norm2(hidden_states, temb)
else:
if self.time_embedding_norm == "default":
if temb is not None:
hidden_states = hidden_states + temb
hidden_states = self.norm2(hidden_states)

if temb is not None and self.time_embedding_norm == "scale_shift":
elif self.time_embedding_norm == "scale_shift":
if temb is None:
raise ValueError(
f" `temb` should not be None when `time_embedding_norm` is {self.time_embedding_norm}"
)
scale, shift = torch.chunk(temb, 2, dim=1)
hidden_states = self.norm2(hidden_states)
hidden_states = hidden_states * (1 + scale) + shift
else:
hidden_states = self.norm2(hidden_states)

hidden_states = self.nonlinearity(hidden_states)

Expand Down
Loading