adds the pipeline for pixart alpha controlnet #8857

Merged · 23 commits · Oct 28, 2024
Changes from 6 commits
Commits
5389c8c
add the controlnet pipeline for pixart alpha
raulc0399 Jul 12, 2024
abcc770
import structure for the pixart alpha controlnet pipeline
raulc0399 Jul 12, 2024
eaa2e21
use PixArtImageProcessor
raulc0399 Jul 12, 2024
3e2ec9a
moved the pixart controlnet in examples
raulc0399 Jul 22, 2024
98f63fb
rollback changes
raulc0399 Jul 22, 2024
e381b40
training script
raulc0399 Jul 22, 2024
2e6ca8c
moved pipeline to community folder
raulc0399 Jul 23, 2024
a6ed8f7
readme section for the pixart controlnet model and pipeline
raulc0399 Jul 24, 2024
908f615
Merge branch 'main' of https://github.com/huggingface/diffusers into …
raulc0399 Sep 16, 2024
531edde
Merge branch 'main' into main_pixart_alpha_controlnet
yiyixuxu Sep 17, 2024
8c343e2
Merge branch 'main' into main_pixart_alpha_controlnet
yiyixuxu Sep 17, 2024
aed7d3d
Merge remote-tracking branch 'src/main' into main_pixart_alpha_contro…
raulc0399 Oct 16, 2024
510a102
Merge branch 'main_pixart_alpha_controlnet' of github.com:raulc0399/d…
raulc0399 Oct 16, 2024
0d1ed9d
moved the pixart controlnet pipeline under research_projects
raulc0399 Oct 16, 2024
d7f03f9
Merge branch 'main' into main_pixart_alpha_controlnet
sayakpaul Oct 18, 2024
9b89559
make style && make quality;
lawrence-cj Oct 25, 2024
5957191
make style && make quality;
lawrence-cj Oct 25, 2024
995a79f
make style && make quality;
lawrence-cj Oct 25, 2024
146dec2
moved the file to research_projects folder
raulc0399 Oct 27, 2024
236e81d
Merge branch 'main_pixart_alpha_controlnet' of github.com:raulc0399/d…
raulc0399 Oct 27, 2024
50679b3
Merge remote-tracking branch 'tmp/main_pixart_alpha_controlnet' into …
lawrence-cj Oct 27, 2024
8bc5599
Merge remote-tracking branch 'src/main' into main_pixart_alpha_contro…
raulc0399 Oct 27, 2024
e3c5c05
Merge remote-tracking branch 'refs/remotes/tmp/main_pixart_alpha_cont…
lawrence-cj Oct 28, 2024
2 changes: 2 additions & 0 deletions examples/pixart/.gitignore
@@ -0,0 +1,2 @@
images/
output/
292 changes: 292 additions & 0 deletions examples/pixart/controlnet_pixart_alpha.py
@@ -0,0 +1,292 @@
from typing import Any, Dict, Optional

Collaborator:

I think maybe pipelines can go to the /example/community folder, the training script can stay in example/pixart folder

cc @sayakpaul

Member:

Okay with that plan.

Member:

This needs to be addressed first.

import torch
from torch import nn

from diffusers.models.attention import BasicTransformerBlock
from diffusers.models import PixArtTransformer2DModel
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.utils import is_torch_version  # used by the gradient checkpointing branch in forward()

class PixArtControlNetAdapterBlock(nn.Module):

Collaborator:

I think we need to copy paste all the model code here into the pipeline so that the pipeline will be able to run, no?

Contributor Author (@raulc0399, Jul 24, 2024):

the pipeline code changes the sys path, so it runs:
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
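
A minimal sketch of that approach (illustrative only, not the PR's actual pipeline code; the module path and directory layout below are assumptions):

# illustrative sketch: make the adapter module importable from a script living in examples/pixart/
import os
import sys

# append the parent of the script's directory (i.e. examples/) to the module search path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# assumed module path, resolving to examples/pixart/controlnet_pixart_alpha.py
from pixart.controlnet_pixart_alpha import PixArtControlNetAdapterModel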

def __init__(
self,
block_index,

# taken from PixArtTransformer2DModel
num_attention_heads: int = 16,
attention_head_dim: int = 72,
dropout: float = 0.0,
cross_attention_dim: Optional[int] = 1152,
attention_bias: bool = True,
activation_fn: str = "gelu-approximate",
num_embeds_ada_norm: Optional[int] = 1000,
upcast_attention: bool = False,
norm_type: str = "ada_norm_single",
norm_elementwise_affine: bool = False,
norm_eps: float = 1e-6,
attention_type: Optional[str] = "default",
):
super().__init__()

self.block_index = block_index
self.inner_dim = num_attention_heads * attention_head_dim

# the first block is preceded by a zero-initialized projection layer
if self.block_index == 0:
self.before_proj = nn.Linear(self.inner_dim, self.inner_dim)
nn.init.zeros_(self.before_proj.weight)
nn.init.zeros_(self.before_proj.bias)

self.transformer_block = BasicTransformerBlock(
self.inner_dim,
num_attention_heads,
attention_head_dim,
dropout=dropout,
cross_attention_dim=cross_attention_dim,
activation_fn=activation_fn,
num_embeds_ada_norm=num_embeds_ada_norm,
attention_bias=attention_bias,
upcast_attention=upcast_attention,
norm_type=norm_type,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
attention_type=attention_type,
)

self.after_proj = nn.Linear(self.inner_dim, self.inner_dim)
nn.init.zeros_(self.after_proj.weight)
nn.init.zeros_(self.after_proj.bias)

def train(self, mode: bool = True):
self.transformer_block.train(mode)

if self.block_index == 0:
self.before_proj.train(mode)

self.after_proj.train(mode)

def forward(
self,
hidden_states: torch.Tensor,
controlnet_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
timestep: Optional[torch.LongTensor] = None,
added_cond_kwargs: Dict[str, torch.Tensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
attention_mask: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
):
if self.block_index == 0:
controlnet_states = self.before_proj(controlnet_states)
controlnet_states = hidden_states + controlnet_states

controlnet_states_down = self.transformer_block(
hidden_states=controlnet_states,
encoder_hidden_states=encoder_hidden_states,
timestep=timestep,
added_cond_kwargs=added_cond_kwargs,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
class_labels=None
)

# the projected output ("left") is added back into the main transformer stream;
# the unprojected output ("down") is passed on to the next controlnet block
controlnet_states_left = self.after_proj(controlnet_states_down)

return controlnet_states_left, controlnet_states_down

class PixArtControlNetAdapterModel(ModelMixin, ConfigMixin):
# N=13, as specified in the paper https://arxiv.org/html/2401.05252v1/#S4 ControlNet-Transformer
@register_to_config
def __init__(self, num_layers = 13) -> None:
super().__init__()

self.num_layers = num_layers

self.controlnet_blocks = nn.ModuleList(
[
PixArtControlNetAdapterBlock(block_index=i)
for i in range(num_layers)
]
)

@classmethod
def from_transformer(cls, transformer: PixArtTransformer2DModel):
control_net = PixArtControlNetAdapterModel()

# copied the specified number of blocks from the transformer
for depth in range(control_net.num_layers):
control_net.controlnet_blocks[depth].transformer_block.load_state_dict(transformer.transformer_blocks[depth].state_dict())

return control_net

def train(self, mode: bool = True):
for block in self.controlnet_blocks:
block.train(mode)

class PixArtControlNetTransformerModel(ModelMixin, ConfigMixin):
def __init__(
self,
transformer: PixArtTransformer2DModel,
controlnet: PixArtControlNetAdapterModel,
blocks_num=13,
init_from_transformer=False,
training=False
):
super().__init__()

self.blocks_num = blocks_num
self.gradient_checkpointing = False
self.register_to_config(**transformer.config)
self.training = training

if init_from_transformer:
# copy the transformer blocks from the base transformer into the controlnet adapter
controlnet = PixArtControlNetAdapterModel.from_transformer(transformer)

self.transformer = transformer
self.controlnet = controlnet

def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value

def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
timestep: Optional[torch.LongTensor] = None,
controlnet_cond: Optional[torch.Tensor] = None,
added_cond_kwargs: Dict[str, torch.Tensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
attention_mask: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
return_dict: bool = True,
):

if self.transformer.use_additional_conditions and added_cond_kwargs is None:
raise ValueError("`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`.")

# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
# expects mask of shape:
# [batch, key_tokens]
# adds singleton query_tokens dimension:
# [batch, 1, key_tokens]
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
if attention_mask is not None and attention_mask.ndim == 2:
# assume that mask is expressed as:
# (1 = keep, 0 = discard)
# convert mask into a bias that can be added to attention scores:
# (keep = +0, discard = -10000.0)
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
attention_mask = attention_mask.unsqueeze(1)

# convert encoder_attention_mask to a bias the same way we do for attention_mask
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

# 1. Input
batch_size = hidden_states.shape[0]
height, width = (
hidden_states.shape[-2] // self.transformer.config.patch_size,
hidden_states.shape[-1] // self.transformer.config.patch_size,
)
hidden_states = self.transformer.pos_embed(hidden_states)

timestep, embedded_timestep = self.transformer.adaln_single(
timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
)

if self.transformer.caption_projection is not None:
encoder_hidden_states = self.transformer.caption_projection(encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])

controlnet_states_down = None
if controlnet_cond is not None:
controlnet_states_down = self.transformer.pos_embed(controlnet_cond)

# 2. Blocks
for block_index, block in enumerate(self.transformer.transformer_blocks):
if self.training and self.gradient_checkpointing:
# rc todo: implement gradient checkpointing for the controlnet transformer model
raise NotImplementedError("Gradient checkpointing is not supported for the controlnet transformer model, yet.")

def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)

return custom_forward

ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
timestep,
cross_attention_kwargs,
None,
**ckpt_kwargs,
)
else:
# the controlnet blocks are applied only to transformer blocks 1 through self.blocks_num
if block_index > 0 and block_index <= self.blocks_num and controlnet_states_down is not None:
controlnet_states_left, controlnet_states_down = self.controlnet.controlnet_blocks[block_index - 1](
hidden_states=hidden_states, # used only in the first block
controlnet_states=controlnet_states_down,
encoder_hidden_states=encoder_hidden_states,
timestep=timestep,
added_cond_kwargs=added_cond_kwargs,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask
)

hidden_states = hidden_states + controlnet_states_left

hidden_states = block(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
timestep=timestep,
cross_attention_kwargs=cross_attention_kwargs,
class_labels=None,
)

# 3. Output
shift, scale = (
self.transformer.scale_shift_table[None] + embedded_timestep[:, None].to(self.transformer.scale_shift_table.device)
).chunk(2, dim=1)
hidden_states = self.transformer.norm_out(hidden_states)
# Modulation
hidden_states = hidden_states * (1 + scale.to(hidden_states.device)) + shift.to(hidden_states.device)
hidden_states = self.transformer.proj_out(hidden_states)
hidden_states = hidden_states.squeeze(1)

# unpatchify
hidden_states = hidden_states.reshape(
shape=(-1, height, width, self.transformer.config.patch_size, self.transformer.config.patch_size, self.transformer.out_channels)
)
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
output = hidden_states.reshape(
shape=(-1, self.transformer.out_channels, height * self.transformer.config.patch_size, width * self.transformer.config.patch_size)
)

if not return_dict:
return (output,)

return Transformer2DModelOutput(sample=output)
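
For reference, a minimal usage sketch of the classes defined in this file (an illustration under assumptions, not code from this PR; it presumes the classes above are importable and uses an assumed checkpoint name):

# illustrative sketch: wire the base PixArt transformer and the controlnet adapter together
from diffusers.models import PixArtTransformer2DModel

# checkpoint name is an assumption used for illustration only
transformer = PixArtTransformer2DModel.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="transformer")

# initialize the adapter blocks from the corresponding blocks of the base transformer
controlnet = PixArtControlNetAdapterModel.from_transformer(transformer)

# the combined model accepts the usual PixArt transformer inputs plus a controlnet_cond tensor
model = PixArtControlNetTransformerModel(transformer=transformer, controlnet=controlnet)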
