Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/user_guide/diffusion_acceleration.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ The following table shows which models are currently supported by each accelerat
| **FLUX.2-klein** | `black-forest-labs/FLUX.2-klein-4B` | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
| **FLUX.2-dev** | `black-forest-labs/FLUX.2-dev` | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
| **FLUX.1-Kontext-dev** | `black-forest-labs/FLUX.1-Kontext-dev` | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ |
| **GLM-Image** | `zai-org/GLM-Image` | ❌ | ❌ | ❌ | ❌ | ✅ | ✅ | ❌ |

### VideoGen

Expand Down
250 changes: 229 additions & 21 deletions vllm_omni/diffusion/models/glm_image/glm_image_transformer.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import math
from collections.abc import Iterable
from enum import Enum
from typing import Any

import torch
import torch.nn as nn
from diffusers.models.attention import FeedForward
import torch.nn.functional as F
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.transformers.transformer_glm_image import GlmImageCombinedTimestepSizeEmbeddings
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import QKVParallelLinear
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from vllm_omni.diffusion.attention.layer import Attention
Expand All @@ -21,6 +26,72 @@
logger = init_logger(__name__)


def _positive_divisors(n: int) -> set[int]:
if n <= 0:
return set()
divs: set[int] = set()
for d in range(1, int(math.isqrt(n)) + 1):
if n % d == 0:
divs.add(d)
divs.add(n // d)
return divs


def validate_glm_image_tp_constraints(
    *,
    dim: int,
    num_heads: int,
    ffn_hidden_dim: int,
    tensor_parallel_size: int,
) -> list[int]:
    """Validate GLM-Image TP constraints without requiring a distributed context.

    The tensor-parallel size must evenly divide the hidden dimension, the
    number of attention heads and the FFN hidden dimension.

    Args:
        dim: Model hidden dimension
        num_heads: Number of attention heads
        ffn_hidden_dim: FFN hidden dimension
        tensor_parallel_size: TP size to validate against

    Returns:
        Sorted list of TP sizes that divide ``dim``, ``num_heads`` and
        ``ffn_hidden_dim`` simultaneously.

    Raises:
        ValueError: If ``tensor_parallel_size`` or any of the sizes is not
            positive, or if ``tensor_parallel_size`` does not divide one of
            the sizes. The first violated constraint (in the order dim,
            num_heads, ffn_hidden_dim) is the one reported.
    """
    tp_size = int(tensor_parallel_size)
    if tp_size <= 0:
        raise ValueError(f"tensor_parallel_size must be > 0, got {tp_size}")

    # Insertion order is the order in which constraints are checked/reported.
    sharded_sizes = {
        "dim": dim,
        "num_heads": num_heads,
        "ffn_hidden_dim": ffn_hidden_dim,
    }

    # Zero or negative sizes would slip through the modulo checks below
    # (e.g. 0 % tp == 0) and produce an empty candidate list instead of an
    # error, so reject them explicitly.
    for name, size in sharded_sizes.items():
        if size <= 0:
            raise ValueError(f"{name} must be > 0, got {size}")

    for name, size in sharded_sizes.items():
        if size % tp_size != 0:
            supported = sorted(_positive_divisors(size))
            raise ValueError(
                f"GLM-Image requires {name} % tensor_parallel_size == 0, "
                f"but got {name}={size}, tp={tp_size}. "
                f"Supported tp candidates by {name}: {supported}"
            )

    # A TP size is usable only if it divides all three sizes.
    common_divisors = set.intersection(
        *(_positive_divisors(size) for size in sharded_sizes.values())
    )
    return sorted(common_divisors)


class GlmImageImageProjector(nn.Module):
"""Projects latent image patches to transformer hidden dimension."""

Expand Down Expand Up @@ -330,35 +401,44 @@ def __init__(
):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.total_num_heads = num_heads
self.head_dim = head_dim
self.inner_dim = num_heads * head_dim

# QKV projection (fused for efficiency)
self.to_qkv = QKVParallelLinear(
hidden_size=dim,
head_size=head_dim,
total_num_heads=num_heads,
disable_tp=True,
total_num_kv_heads=num_heads,
bias=True,
return_bias=False,
)

# QK normalization (LayerNorm, not RMSNorm for GLM-Image)
self.norm_q = nn.LayerNorm(head_dim, elementwise_affine=False, eps=eps)
self.norm_k = nn.LayerNorm(head_dim, elementwise_affine=False, eps=eps)

# Output projection
self.to_out = nn.Sequential(
nn.Linear(self.inner_dim, dim, bias=out_bias),
nn.Dropout(0.0),
self.to_out = nn.ModuleList(
[
RowParallelLinear(
dim,
dim,
bias=out_bias,
input_is_parallel=True,
return_bias=False,
),
nn.Dropout(0.0),
]
)

# Attention
self.attn = Attention(
num_heads=num_heads,
num_heads=self.to_qkv.num_heads,
head_size=head_dim,
softmax_scale=1.0 / (head_dim**0.5),
causal=False,
num_kv_heads=self.to_qkv.num_kv_heads,
)

def forward(
Expand Down Expand Up @@ -390,13 +470,15 @@ def forward(
hidden_states_combined = torch.cat([encoder_hidden_states, hidden_states], dim=1)

# QKV projection
qkv, _ = self.to_qkv(hidden_states_combined)
query, key, value = qkv.chunk(3, dim=-1)
qkv = self.to_qkv(hidden_states_combined)
q_size = self.to_qkv.num_heads * self.head_dim
kv_size = self.to_qkv.num_kv_heads * self.head_dim
query, key, value = qkv.split([q_size, kv_size, kv_size], dim=-1)

# Reshape: [B, S, H*D] -> [B, S, H, D]
query = query.unflatten(-1, (self.num_heads, -1))
key = key.unflatten(-1, (self.num_heads, -1))
value = value.unflatten(-1, (self.num_heads, -1))
query = query.unflatten(-1, (self.to_qkv.num_heads, -1))
key = key.unflatten(-1, (self.to_qkv.num_kv_heads, -1))
value = value.unflatten(-1, (self.to_qkv.num_kv_heads, -1))

# QK normalization
query = self.norm_q(query).to(dtype=dtype)
Expand Down Expand Up @@ -431,7 +513,8 @@ def forward(
hidden_states_out = hidden_states_out.to(dtype)

# Output projection
hidden_states_out = self.to_out(hidden_states_out)
for module in self.to_out:
hidden_states_out = module(hidden_states_out)

# Split back to text and image
encoder_hidden_states_out = hidden_states_out[:, :text_seq_length, :]
Expand All @@ -440,6 +523,100 @@ def forward(
return hidden_states_out, encoder_hidden_states_out


class ColumnParallelGELU(nn.Module):
    """Linear projection followed by GELU, using a column-parallel linear layer.

    The projection is created with ``gather_output=False`` and
    ``return_bias=False``; presumably each tensor-parallel rank therefore keeps
    only its own shard of the output features and ``proj`` returns a bare
    tensor -- confirm against vLLM's ``ColumnParallelLinear``.

    Args:
        dim_in: Input feature size of the projection.
        dim_out: Output feature size of the projection.
        approximate: Forwarded to ``F.gelu`` as its ``approximate`` argument.
        bias: Whether the projection has a bias term.
    """

    def __init__(
        self,
        dim_in: int,
        dim_out: int,
        *,
        approximate: str = "none",
        bias: bool = True,
    ):
        super().__init__()
        self.approximate = approximate
        self.proj = ColumnParallelLinear(
            dim_in,
            dim_out,
            bias=bias,
            gather_output=False,
            return_bias=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` and apply GELU to the result."""
        projected = self.proj(x)
        return F.gelu(projected, approximate=self.approximate)


class ColumnParallelSiLU(nn.Module):
    """Linear projection followed by SiLU, using a column-parallel linear layer.

    The projection is created with ``gather_output=False`` and
    ``return_bias=False``; presumably each tensor-parallel rank therefore keeps
    only its own shard of the output features and ``proj`` returns a bare
    tensor -- confirm against vLLM's ``ColumnParallelLinear``.

    Args:
        dim_in: Input feature size of the projection.
        dim_out: Output feature size of the projection.
        bias: Whether the projection has a bias term.
    """

    def __init__(
        self,
        dim_in: int,
        dim_out: int,
        *,
        bias: bool = True,
    ):
        super().__init__()
        projection = ColumnParallelLinear(
            dim_in,
            dim_out,
            bias=bias,
            gather_output=False,
            return_bias=False,
        )
        self.proj = projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` and apply SiLU to the result."""
        return F.silu(self.proj(x))


class GlmImageFeedForward(nn.Module):
def __init__(
self,
dim: int,
dim_out: int | None = None,
mult: int = 4,
inner_dim: int | None = None,
bias: bool = True,
activation_fn: str = "gelu",
):
super().__init__()
inner_dim = inner_dim or int(dim * mult)
dim_out = dim_out or dim

if activation_fn == "linear-silu":
layers: list[nn.Module] = [
ColumnParallelSiLU(dim, inner_dim, bias=bias),
nn.Identity(),
RowParallelLinear(
inner_dim,
dim_out,
bias=bias,
input_is_parallel=True,
return_bias=False,
),
]
else:
approximate = "tanh" if activation_fn == "gelu-approximate" else "none"
layers = [
ColumnParallelGELU(dim, inner_dim, approximate=approximate, bias=bias),
nn.Identity(),
RowParallelLinear(
inner_dim,
dim_out,
bias=bias,
input_is_parallel=True,
return_bias=False,
),
]

self.net = nn.ModuleList(layers)

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
for module in self.net:
hidden_states = module(hidden_states)
return hidden_states


class GlmImageTransformerBlock(nn.Module):
"""Single transformer block for GLM-Image."""

Expand All @@ -449,6 +626,7 @@ def __init__(
num_attention_heads: int = 64,
attention_head_dim: int = 40,
time_embed_dim: int = 512,
ffn_hidden_dim: int | None = None,
) -> None:
super().__init__()

Expand All @@ -463,7 +641,7 @@ def __init__(
# 2. Feedforward
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
self.ff = GlmImageFeedForward(dim=dim, dim_out=dim, inner_dim=ffn_hidden_dim, activation_fn="gelu-approximate")

def forward(
self,
Expand Down Expand Up @@ -564,26 +742,50 @@ def __init__(

# Get num_layers from config if available
model_config = od_config.tf_model_config
if model_config is not None and hasattr(model_config, "num_layers"):
num_layers = model_config.num_layers
num_layers = getattr(model_config, "num_layers", 28) if model_config is not None else 28

self.od_config = od_config
self.patch_size = patch_size
self.in_channels = in_channels
self.out_channels = out_channels
self.parallel_config = od_config.parallel_config

# GlmImage uses 2 additional SDXL-like conditions - target_size, crop_coords
pooled_projection_dim = 2 * 2 * condition_dim
inner_dim = num_attention_heads * attention_head_dim

tp_size = self.parallel_config.tensor_parallel_size
ffn_hidden_dim = inner_dim * 4

supported_tp_candidates = validate_glm_image_tp_constraints(
dim=inner_dim,
num_heads=num_attention_heads,
ffn_hidden_dim=ffn_hidden_dim,
tensor_parallel_size=tp_size,
)

logger.info_once(
"GLM-Image init: dim=%d num_heads=%d head_dim=%d ffn_hidden_dim=%d tp=%d (supported_tp=%s)",
inner_dim,
num_attention_heads,
attention_head_dim,
ffn_hidden_dim,
tp_size,
tuple(supported_tp_candidates),
)

# 1. RoPE
self.rope = GlmImageRotaryPosEmbed(attention_head_dim, patch_size, theta=10000.0)

# 2. Patch & Text-timestep embedding
self.image_projector = GlmImageImageProjector(in_channels, inner_dim, patch_size)
self.glyph_projector = FeedForward(text_embed_dim, inner_dim, inner_dim=inner_dim, activation_fn="gelu")
self.glyph_projector = GlmImageFeedForward(
dim=text_embed_dim, dim_out=inner_dim, inner_dim=inner_dim, activation_fn="gelu"
)
self.prior_token_embedding = nn.Embedding(prior_vq_quantizer_codebook_size, inner_dim)
self.prior_projector = FeedForward(inner_dim, inner_dim, inner_dim=inner_dim, activation_fn="linear-silu")
self.prior_projector = GlmImageFeedForward(
dim=inner_dim, dim_out=inner_dim, inner_dim=inner_dim, activation_fn="linear-silu"
)

self.time_condition_embed = GlmImageCombinedTimestepSizeEmbeddings(
embedding_dim=time_embed_dim,
Expand All @@ -595,7 +797,13 @@ def __init__(
# 3. Transformer blocks
self.transformer_blocks = nn.ModuleList(
[
GlmImageTransformerBlock(inner_dim, num_attention_heads, attention_head_dim, time_embed_dim)
GlmImageTransformerBlock(
inner_dim,
num_attention_heads,
attention_head_dim,
time_embed_dim,
ffn_hidden_dim=ffn_hidden_dim,
)
for _ in range(num_layers)
]
)
Expand Down
Loading