225 changes: 81 additions & 144 deletions comfy/ldm/lumina/model.py
@@ -11,6 +11,7 @@
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder
from comfy.ldm.modules.attention import optimized_attention_masked
from comfy.ldm.flux.layers import EmbedND
from comfy.ldm.flux.math import apply_rope
import comfy.patcher_extension


@@ -31,6 +32,7 @@ def __init__(
n_heads: int,
n_kv_heads: Optional[int],
qk_norm: bool,
out_bias: bool = False,
operation_settings={},
):
"""
@@ -59,7 +61,7 @@ def __init__(
self.out = operation_settings.get("operations").Linear(
n_heads * self.head_dim,
dim,
bias=False,
bias=out_bias,
device=operation_settings.get("device"),
dtype=operation_settings.get("dtype"),
)
@@ -70,35 +72,6 @@ def __init__(
else:
self.q_norm = self.k_norm = nn.Identity()

@staticmethod
def apply_rotary_emb(
x_in: torch.Tensor,
freqs_cis: torch.Tensor,
) -> torch.Tensor:
"""
Apply rotary embeddings to input tensors using the given frequency
tensor.
This function applies rotary embeddings to the given query 'xq' and
key 'xk' tensors using the provided frequency tensor 'freqs_cis'. The
input tensors are reshaped as complex numbers, and the frequency tensor
is reshaped for broadcasting compatibility. The resulting tensors
contain rotary embeddings and are returned as real tensors.
Args:
x_in (torch.Tensor): Query or Key tensor to apply rotary embeddings.
freqs_cis (torch.Tensor): Precomputed frequency tensor for complex
exponentials.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor
and key tensor with rotary embeddings.
"""

t_ = x_in.reshape(*x_in.shape[:-1], -1, 1, 2)
t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
return t_out.reshape(*x_in.shape)

def forward(
self,
x: torch.Tensor,
@@ -134,8 +107,7 @@ def forward(
xq = self.q_norm(xq)
xk = self.k_norm(xk)

xq = JointAttention.apply_rotary_emb(xq, freqs_cis=freqs_cis)
xk = JointAttention.apply_rotary_emb(xk, freqs_cis=freqs_cis)
xq, xk = apply_rope(xq, xk, freqs_cis)

n_rep = self.n_local_heads // self.n_local_kv_heads
if n_rep >= 1:
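Note: the deleted `apply_rotary_emb` staticmethod and the shared `apply_rope` helper imported from `comfy.ldm.flux.math` perform the same real-valued rotation; the helper simply rotates the query and key in one call. A minimal standalone sketch of that rotation, mirroring the arithmetic of the removed method (function names here are illustrative, not from the PR):

import torch

def rotate_half_pairs(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # View the last dim of x as d/2 (real, imag) pairs and apply the 2x2
    # rotation stored in freqs_cis, exactly as the removed staticmethod did.
    x_ = x.reshape(*x.shape[:-1], -1, 1, 2)
    out = freqs_cis[..., 0] * x_[..., 0] + freqs_cis[..., 1] * x_[..., 1]
    return out.reshape(*x.shape)

def rotate_qk(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
    # Same calling pattern as apply_rope(xq, xk, freqs_cis): rotate both tensors at once.
    return rotate_half_pairs(xq, freqs_cis), rotate_half_pairs(xk, freqs_cis)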
Expand Down Expand Up @@ -215,6 +187,8 @@ def __init__(
norm_eps: float,
qk_norm: bool,
modulation=True,
z_image_modulation=False,
attn_out_bias=False,
operation_settings={},
) -> None:
"""
@@ -235,10 +209,10 @@
super().__init__()
self.dim = dim
self.head_dim = dim // n_heads
self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, operation_settings=operation_settings)
self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, out_bias=attn_out_bias, operation_settings=operation_settings)
self.feed_forward = FeedForward(
dim=dim,
hidden_dim=4 * dim,
hidden_dim=dim,
multiple_of=multiple_of,
ffn_dim_multiplier=ffn_dim_multiplier,
operation_settings=operation_settings,
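Note: switching the block from `hidden_dim=4 * dim` to `hidden_dim=dim` goes together with the new `ffn_dim_multiplier: float = 4.0` default further down, so the effective FFN width for existing checkpoints should be unchanged. A sketch of the sizing arithmetic, assuming the usual LLaMA-style `FeedForward` sizing (2/3 shrink, optional multiplier, round up to `multiple_of`); this helper is illustrative, not the repository's code:

def ffn_hidden(hidden_dim, multiple_of=256, ffn_dim_multiplier=None):
    # Assumed LLaMA-style sizing: shrink to 2/3, apply the optional multiplier,
    # then round up to the nearest multiple_of.
    hidden = int(2 * hidden_dim / 3)
    if ffn_dim_multiplier is not None:
        hidden = int(ffn_dim_multiplier * hidden)
    return multiple_of * ((hidden + multiple_of - 1) // multiple_of)

dim = 4096
print(ffn_hidden(4 * dim))                      # old call: hidden_dim=4*dim, no multiplier
print(ffn_hidden(dim, ffn_dim_multiplier=4.0))  # new call: hidden_dim=dim, multiplier=4.0
# Both print 11008 for dim=4096 under this assumed formula, so the default width is preserved.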
@@ -252,16 +226,27 @@ def __init__(

self.modulation = modulation
if modulation:
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
operation_settings.get("operations").Linear(
min(dim, 1024),
4 * dim,
bias=True,
device=operation_settings.get("device"),
dtype=operation_settings.get("dtype"),
),
)
if z_image_modulation:
self.adaLN_modulation = nn.Sequential(
operation_settings.get("operations").Linear(
min(dim, 256),
4 * dim,
bias=True,
device=operation_settings.get("device"),
dtype=operation_settings.get("dtype"),
),
)
else:
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
operation_settings.get("operations").Linear(
min(dim, 1024),
4 * dim,
bias=True,
device=operation_settings.get("device"),
dtype=operation_settings.get("dtype"),
),
)

def forward(
self,
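Note: both adaLN branches project the conditioning vector to `4 * dim` modulation parameters; the z-image branch drops the SiLU and caps the input width at 256 instead of 1024, matching the narrower timestep embedding requested later via `output_size=256`. A toy sketch of just these two projection heads (sizes are illustrative, not from the PR):

import torch
import torch.nn as nn

dim, batch = 4096, 2  # illustrative sizes

# Default path: SiLU followed by a Linear over a min(dim, 1024)-wide input.
adaln_default = nn.Sequential(nn.SiLU(), nn.Linear(min(dim, 1024), 4 * dim, bias=True))

# z-image path: a bare Linear over a min(dim, 256)-wide input.
adaln_z_image = nn.Sequential(nn.Linear(min(dim, 256), 4 * dim, bias=True))

print(adaln_default(torch.randn(batch, min(dim, 1024))).shape)  # torch.Size([2, 16384])
print(adaln_z_image(torch.randn(batch, min(dim, 256))).shape)   # torch.Size([2, 16384])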
@@ -323,7 +308,7 @@ class FinalLayer(nn.Module):
The final layer of NextDiT.
"""

def __init__(self, hidden_size, patch_size, out_channels, operation_settings={}):
def __init__(self, hidden_size, patch_size, out_channels, z_image_modulation=False, operation_settings={}):
super().__init__()
self.norm_final = operation_settings.get("operations").LayerNorm(
hidden_size,
@@ -340,10 +325,15 @@ def __init__(self, hidden_size, patch_size, out_channels, operation_settings={})
dtype=operation_settings.get("dtype"),
)

if z_image_modulation:
min_mod = 256
else:
min_mod = 1024

self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
operation_settings.get("operations").Linear(
min(hidden_size, 1024),
min(hidden_size, min_mod),
hidden_size,
bias=True,
device=operation_settings.get("device"),
@@ -373,12 +363,16 @@ def __init__(
n_heads: int = 32,
n_kv_heads: Optional[int] = None,
multiple_of: int = 256,
ffn_dim_multiplier: Optional[float] = None,
ffn_dim_multiplier: float = 4.0,
norm_eps: float = 1e-5,
qk_norm: bool = False,
cap_feat_dim: int = 5120,
axes_dims: List[int] = (16, 56, 56),
axes_lens: List[int] = (1, 512, 512),
rope_theta=10000.0,
z_image_modulation=False,
time_scale=1.0,
pad_tokens_multiple=None,
image_model=None,
device=None,
dtype=None,
@@ -390,6 +384,8 @@
self.in_channels = in_channels
self.out_channels = in_channels
self.patch_size = patch_size
self.time_scale = time_scale
self.pad_tokens_multiple = pad_tokens_multiple

self.x_embedder = operation_settings.get("operations").Linear(
in_features=patch_size * patch_size * in_channels,
@@ -411,6 +407,7 @@ def __init__(
norm_eps,
qk_norm,
modulation=True,
z_image_modulation=z_image_modulation,
operation_settings=operation_settings,
)
for layer_id in range(n_refiner_layers)
@@ -434,7 +431,7 @@ def __init__(
]
)

self.t_embedder = TimestepEmbedder(min(dim, 1024), **operation_settings)
self.t_embedder = TimestepEmbedder(min(dim, 1024), output_size=256 if z_image_modulation else None, **operation_settings)
self.cap_embedder = nn.Sequential(
operation_settings.get("operations").RMSNorm(cap_feat_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
operation_settings.get("operations").Linear(
@@ -457,18 +454,24 @@ def __init__(
ffn_dim_multiplier,
norm_eps,
qk_norm,
z_image_modulation=z_image_modulation,
attn_out_bias=False,
operation_settings=operation_settings,
)
for layer_id in range(n_layers)
]
)
self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.final_layer = FinalLayer(dim, patch_size, self.out_channels, operation_settings=operation_settings)
self.final_layer = FinalLayer(dim, patch_size, self.out_channels, z_image_modulation=z_image_modulation, operation_settings=operation_settings)

if self.pad_tokens_multiple is not None:
self.x_pad_token = nn.Parameter(torch.empty((1, dim), device=device, dtype=dtype))
self.cap_pad_token = nn.Parameter(torch.empty((1, dim), device=device, dtype=dtype))

assert (dim // n_heads) == sum(axes_dims)
self.axes_dims = axes_dims
self.axes_lens = axes_lens
self.rope_embedder = EmbedND(dim=dim // n_heads, theta=10000.0, axes_dim=axes_dims)
self.rope_embedder = EmbedND(dim=dim // n_heads, theta=rope_theta, axes_dim=axes_dims)
self.dim = dim
self.n_heads = n_heads
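Note: when `pad_tokens_multiple` is set, the learned `x_pad_token` / `cap_pad_token` parameters are appended in `patchify_and_embed` so the image and caption sequences reach a multiple of that length; the padding amount is computed as `(-length) % multiple`. A small sketch of that arithmetic (the `pad_to_multiple` helper is hypothetical):

import torch

def pad_to_multiple(tokens: torch.Tensor, pad_token: torch.Tensor, multiple: int) -> torch.Tensor:
    # tokens: (B, L, D), pad_token: (1, D). Append repeated pad tokens until L is a
    # multiple of `multiple`; (-L) % multiple is 0 when the length is already aligned.
    pad_extra = (-tokens.shape[1]) % multiple
    if pad_extra == 0:
        return tokens
    pad = pad_token.to(device=tokens.device, dtype=tokens.dtype)
    pad = pad.unsqueeze(0).repeat(tokens.shape[0], pad_extra, 1)
    return torch.cat((tokens, pad), dim=1)

x = torch.randn(2, 1000, 64)
pad_token = torch.zeros(1, 64)
print(pad_to_multiple(x, pad_token, 32).shape)  # torch.Size([2, 1024, 64]) -> padded by 24 tokens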

@@ -503,108 +506,42 @@ def patchify_and_embed(
bsz = len(x)
pH = pW = self.patch_size
device = x[0].device
dtype = x[0].dtype

if cap_mask is not None:
l_effective_cap_len = cap_mask.sum(dim=1).tolist()
else:
l_effective_cap_len = [num_tokens] * bsz
if self.pad_tokens_multiple is not None:
pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple
cap_feats = torch.cat((cap_feats, self.cap_pad_token.to(device=cap_feats.device, dtype=cap_feats.dtype).unsqueeze(0).repeat(cap_feats.shape[0], pad_extra, 1)), dim=1)

if cap_mask is not None and not torch.is_floating_point(cap_mask):
cap_mask = (cap_mask - 1).to(dtype) * torch.finfo(dtype).max
cap_pos_ids = torch.zeros(bsz, cap_feats.shape[1], 3, dtype=torch.float32, device=device)
cap_pos_ids[:, :, 0] = torch.arange(cap_feats.shape[1], dtype=torch.float32, device=device) + 1.0

img_sizes = [(img.size(1), img.size(2)) for img in x]
l_effective_img_len = [(H // pH) * (W // pW) for (H, W) in img_sizes]
B, C, H, W = x.shape
x = self.x_embedder(x.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2))

max_seq_len = max(
(cap_len+img_len for cap_len, img_len in zip(l_effective_cap_len, l_effective_img_len))
)
max_cap_len = max(l_effective_cap_len)
max_img_len = max(l_effective_img_len)

position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.float32, device=device)

for i in range(bsz):
cap_len = l_effective_cap_len[i]
img_len = l_effective_img_len[i]
H, W = img_sizes[i]
H_tokens, W_tokens = H // pH, W // pW
assert H_tokens * W_tokens == img_len

rope_options = transformer_options.get("rope_options", None)
h_scale = 1.0
w_scale = 1.0
h_start = 0
w_start = 0
if rope_options is not None:
h_scale = rope_options.get("scale_y", 1.0)
w_scale = rope_options.get("scale_x", 1.0)

h_start = rope_options.get("shift_y", 0.0)
w_start = rope_options.get("shift_x", 0.0)

position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.float32, device=device)
position_ids[i, cap_len:cap_len+img_len, 0] = cap_len
row_ids = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten()
col_ids = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten()
position_ids[i, cap_len:cap_len+img_len, 1] = row_ids
position_ids[i, cap_len:cap_len+img_len, 2] = col_ids

freqs_cis = self.rope_embedder(position_ids).movedim(1, 2).to(dtype)

# build freqs_cis for cap and image individually
cap_freqs_cis_shape = list(freqs_cis.shape)
# cap_freqs_cis_shape[1] = max_cap_len
cap_freqs_cis_shape[1] = cap_feats.shape[1]
cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)

img_freqs_cis_shape = list(freqs_cis.shape)
img_freqs_cis_shape[1] = max_img_len
img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)

for i in range(bsz):
cap_len = l_effective_cap_len[i]
img_len = l_effective_img_len[i]
cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len]
img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len:cap_len+img_len]
H_tokens, W_tokens = H // pH, W // pW
x_pos_ids = torch.zeros((bsz, x.shape[1], 3), dtype=torch.float32, device=device)
x_pos_ids[:, :, 0] = cap_feats.shape[1] + 1
x_pos_ids[:, :, 1] = torch.arange(H_tokens, dtype=torch.float32, device=device).view(-1, 1).repeat(1, W_tokens).flatten()
x_pos_ids[:, :, 2] = torch.arange(W_tokens, dtype=torch.float32, device=device).view(1, -1).repeat(H_tokens, 1).flatten()

# refine context
for layer in self.context_refiner:
cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis, transformer_options=transformer_options)

# refine image
flat_x = []
for i in range(bsz):
img = x[i]
C, H, W = img.size()
img = img.view(C, H // pH, pH, W // pW, pW).permute(1, 3, 2, 4, 0).flatten(2).flatten(0, 1)
flat_x.append(img)
x = flat_x
padded_img_embed = torch.zeros(bsz, max_img_len, x[0].shape[-1], device=device, dtype=x[0].dtype)
padded_img_mask = torch.zeros(bsz, max_img_len, dtype=dtype, device=device)
for i in range(bsz):
padded_img_embed[i, :l_effective_img_len[i]] = x[i]
padded_img_mask[i, l_effective_img_len[i]:] = -torch.finfo(dtype).max

padded_img_embed = self.x_embedder(padded_img_embed)
padded_img_mask = padded_img_mask.unsqueeze(1)
for layer in self.noise_refiner:
padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t, transformer_options=transformer_options)
if self.pad_tokens_multiple is not None:
pad_extra = (-x.shape[1]) % self.pad_tokens_multiple
x = torch.cat((x, self.x_pad_token.to(device=x.device, dtype=x.dtype).unsqueeze(0).repeat(x.shape[0], pad_extra, 1)), dim=1)
x_pos_ids = torch.nn.functional.pad(x_pos_ids, (0, 0, 0, pad_extra))

if cap_mask is not None:
mask = torch.zeros(bsz, max_seq_len, dtype=dtype, device=device)
mask[:, :max_cap_len] = cap_mask[:, :max_cap_len]
else:
mask = None
freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2)

padded_full_embed = torch.zeros(bsz, max_seq_len, self.dim, device=device, dtype=x[0].dtype)
for i in range(bsz):
cap_len = l_effective_cap_len[i]
img_len = l_effective_img_len[i]
# refine context
for layer in self.context_refiner:
cap_feats = layer(cap_feats, cap_mask, freqs_cis[:, :cap_pos_ids.shape[1]], transformer_options=transformer_options)

padded_full_embed[i, :cap_len] = cap_feats[i, :cap_len]
padded_full_embed[i, cap_len:cap_len+img_len] = padded_img_embed[i, :img_len]
padded_img_mask = None
for layer in self.noise_refiner:
x = layer(x, padded_img_mask, freqs_cis[:, cap_pos_ids.shape[1]:], t, transformer_options=transformer_options)

padded_full_embed = torch.cat((cap_feats, x), dim=1)
mask = None
img_sizes = [(H, W)] * bsz
l_effective_cap_len = [cap_feats.shape[1]] * bsz
return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis

def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
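Note: the rewritten `patchify_and_embed` replaces the per-sample Python loops with one batched reshape, builds 3-axis position ids (caption index, token row, token column) for caption and image tokens, and runs the refiners on the whole batch, so per-sample masks and lengths collapse to constants. A minimal sketch of the batched patchify step and the resulting token shape, following the `view/permute/flatten` chain used above (sizes are illustrative):

import torch

B, C, H, W = 2, 16, 64, 64   # toy latent batch
pH = pW = 2                  # patch_size

x = torch.randn(B, C, H, W)
tokens = (
    x.view(B, C, H // pH, pH, W // pW, pW)
     .permute(0, 2, 4, 3, 5, 1)   # (B, H/pH, W/pW, pH, pW, C)
     .flatten(3)                  # merge pH, pW, C into one patch vector
     .flatten(1, 2)               # merge the H/pH x W/pW grid into a single token axis
)
print(tokens.shape)  # torch.Size([2, 1024, 64]) -> (B, (H//pH) * (W//pW), pH * pW * C)
# These tokens are what x_embedder projects to the model width before the noise refiner runs.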
@@ -627,7 +564,7 @@ def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwa
y: (N,) tensor of text tokens/features
"""

t = self.t_embedder(t, dtype=x.dtype) # (N, D)
t = self.t_embedder(t * self.time_scale, dtype=x.dtype) # (N, D)
adaln_input = t

cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute
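Note: `time_scale` rescales the raw timestep before it reaches the timestep embedder; the default of 1.0 keeps current behaviour, and a model config can pass a different value if its conditioning expects another timestep range. A trivial sketch (the scale value below is assumed for illustration, not taken from this PR):

import torch

t = torch.tensor([0.25, 0.50])   # timesteps as handed to _forward
time_scale = 1000.0              # assumed config value
print(t * time_scale)            # tensor([250., 500.]) -> what t_embedder actually embeds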
6 changes: 4 additions & 2 deletions comfy/ldm/modules/diffusionmodules/mmdit.py
@@ -211,12 +211,14 @@ class TimestepEmbedder(nn.Module):
Embeds scalar timesteps into vector representations.
"""

def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None):
def __init__(self, hidden_size, frequency_embedding_size=256, output_size=None, dtype=None, device=None, operations=None):
super().__init__()
if output_size is None:
output_size = hidden_size
self.mlp = nn.Sequential(
operations.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device),
nn.SiLU(),
operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
operations.Linear(hidden_size, output_size, bias=True, dtype=dtype, device=device),
)
self.frequency_embedding_size = frequency_embedding_size
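Note: the new `output_size` argument only changes the width of the second MLP layer; `None` preserves the old behaviour (output width equals `hidden_size`), while the Lumina change above passes `output_size=256` when `z_image_modulation` is enabled, matching its 256-wide modulation layers. A self-contained sketch of the same structure with a standard sinusoidal frequency embedding (this is not the repository's `timestep_embedding` helper):

import math
import torch
import torch.nn as nn

class TinyTimestepEmbedder(nn.Module):
    def __init__(self, hidden_size, frequency_embedding_size=256, output_size=None):
        super().__init__()
        if output_size is None:      # default: keep the original output width
            output_size = hidden_size
        self.freq_size = frequency_embedding_size
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, output_size),
        )

    def forward(self, t):
        # Standard sinusoidal embedding of scalar timesteps, then the two-layer MLP.
        half = self.freq_size // 2
        freqs = torch.exp(-math.log(10000) * torch.arange(half, dtype=torch.float32, device=t.device) / half)
        args = t[:, None].float() * freqs[None]
        emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        return self.mlp(emb)

print(TinyTimestepEmbedder(1024)(torch.rand(4)).shape)                   # torch.Size([4, 1024])
print(TinyTimestepEmbedder(1024, output_size=256)(torch.rand(4)).shape)  # torch.Size([4, 256])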
