diff --git a/vllm_omni/diffusion/models/glm_image/__init__.py b/vllm_omni/diffusion/models/glm_image/__init__.py
new file mode 100644
index 00000000000..fc8256d8de6
--- /dev/null
+++ b/vllm_omni/diffusion/models/glm_image/__init__.py
@@ -0,0 +1,21 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""GLM Image diffusion model components."""
+
+from vllm_omni.diffusion.models.glm_image.glm_image_transformer import (
+    GlmImageKVCache,
+    GlmImageTransformer2DModel,
+)
+from vllm_omni.diffusion.models.glm_image.pipeline_glm_image import (
+    GlmImagePipeline,
+    get_glm_image_post_process_func,
+    get_glm_image_pre_process_func,
+)
+
+__all__ = [
+    "GlmImageKVCache",
+    "GlmImagePipeline",
+    "GlmImageTransformer2DModel",
+    "get_glm_image_post_process_func",
+    "get_glm_image_pre_process_func",
+]
diff --git a/vllm_omni/diffusion/models/glm_image/glm_image_transformer.py b/vllm_omni/diffusion/models/glm_image/glm_image_transformer.py
new file mode 100644
index 00000000000..4aceb4ebfff
--- /dev/null
+++ b/vllm_omni/diffusion/models/glm_image/glm_image_transformer.py
@@ -0,0 +1,797 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from collections.abc import Iterable
+from enum import Enum
+from typing import Any
+
+import torch
+import torch.nn as nn
+from diffusers.models.attention import FeedForward
+from diffusers.models.embeddings import GlmImageCombinedTimestepSizeEmbeddings
+from diffusers.models.modeling_outputs import Transformer2DModelOutput
+from vllm.logger import init_logger
+from vllm.model_executor.layers.linear import QKVParallelLinear
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+from vllm_omni.diffusion.attention.layer import Attention
+from vllm_omni.diffusion.cache.base import CachedTransformer
+from vllm_omni.diffusion.data import OmniDiffusionConfig
+from vllm_omni.diffusion.layers.rope import RotaryEmbedding
class GlmImageImageProjector(nn.Module):
    """Cut a latent image into square patches and linearly embed each patch.

    Args:
        in_channels: Number of channels of the incoming latent.
        hidden_size: Output feature size of the linear projection.
        patch_size: Side length of each square patch.
    """

    def __init__(
        self,
        in_channels: int = 16,
        hidden_size: int = 2560,
        patch_size: int = 2,
    ):
        super().__init__()
        self.patch_size = patch_size
        patch_features = in_channels * patch_size**2
        self.proj = nn.Linear(patch_features, hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return one embedded token per patch.

        Args:
            hidden_states: Latent tensor of shape [B, C, H, W].

        Returns:
            Tensor of shape [B, (H // p) * (W // p), hidden_size]; patches are
            ordered row-major over the patch grid, and each patch is flattened
            in (channel, patch-row, patch-col) order before the projection.
        """
        bsz, n_ch, full_h, full_w = hidden_states.shape
        p = self.patch_size
        grid_h, grid_w = full_h // p, full_w // p

        # [B, C, H, W] -> [B, C, H', p, W', p]
        patches = hidden_states.reshape(bsz, n_ch, grid_h, p, grid_w, p)
        # -> [B, H', W', C, p, p] so every patch's values sit in the trailing dims
        patches = patches.permute(0, 2, 4, 1, 3, 5)
        # -> [B, H'*W', C*p*p]
        tokens = patches.reshape(bsz, grid_h * grid_w, n_ch * p * p)
        return self.proj(tokens)
class GlmImageAdaLayerNormZero(nn.Module):
    """Timestep-conditioned LayerNorm modulation for the image and text streams.

    A single linear layer maps the conditioning embedding to twelve chunks of
    size ``dim``. The chunks alternate image/text: shift, scale and gate for the
    attention branch, then shift, scale and gate for the MLP branch.
    """

    def __init__(self, embedding_dim: int, dim: int) -> None:
        super().__init__()
        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
        self.norm_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
        self.linear = nn.Linear(embedding_dim, 12 * dim, bias=True)

    def forward(
        self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor
    ) -> tuple[torch.Tensor, ...]:
        """Normalize and modulate both streams for the attention branch.

        Args:
            hidden_states: Image-stream activations.
            encoder_hidden_states: Text-stream activations.
            temb: Conditioning embedding fed to the modulation linear layer.

        Returns:
            A 10-tuple: modulated image states, image ``gate_msa``,
            ``shift_mlp``, ``scale_mlp``, ``gate_mlp``, then modulated text
            states and the text ``c_gate_msa``, ``c_shift_mlp``,
            ``c_scale_mlp``, ``c_gate_mlp``.
        """
        # Both normalized streams are cast back to the image stream's dtype.
        out_dtype = hidden_states.dtype
        img_normed = self.norm(hidden_states).to(dtype=out_dtype)
        txt_normed = self.norm_context(encoder_hidden_states).to(dtype=out_dtype)

        chunks = self.linear(temb).chunk(12, dim=1)
        # Even positions belong to the image stream, odd positions to the text stream.
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = chunks[0::2]
        c_shift_msa, c_scale_msa, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = chunks[1::2]

        img_modulated = img_normed * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
        txt_modulated = txt_normed * (1 + c_scale_msa.unsqueeze(1)) + c_shift_msa.unsqueeze(1)

        image_outputs = (img_modulated, gate_msa, shift_mlp, scale_mlp, gate_mlp)
        text_outputs = (txt_modulated, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp)
        return image_outputs + text_outputs
class KVCacheMode(Enum):
    """How an attention layer should treat its KV cache on a forward pass.

    Members are looked up by their lowercase string value, e.g.
    ``KVCacheMode("write")``.
    """

    # Store the K/V tensors computed for the condition images.
    WRITE = "write"
    # Prepend the cached K/V to the K/V of the current pass.
    READ = "read"
    # Pass through without touching the cache.
    SKIP = "skip"
class GlmImageKVCache:
    """Container holding one :class:`GlmImageLayerKVCache` per transformer layer.

    Provides a single place to switch the cache mode for every layer and to
    clear all cached tensors.

    Args:
        num_layers: Number of transformer layers in the model.

    Example:
        kv_cache = GlmImageKVCache(num_layers=28)
        kv_cache.set_mode(KVCacheMode.WRITE)
        # ... process condition image ...
        kv_cache.set_mode(KVCacheMode.READ)
        # ... process target image ...
        kv_cache.clear()
    """

    def __init__(self, num_layers: int):
        self.num_layers = num_layers
        self.caches = [GlmImageLayerKVCache() for _ in range(num_layers)]
        self._mode: KVCacheMode | None = None

    def __getitem__(self, layer_idx: int) -> GlmImageLayerKVCache:
        """Return the cache of layer ``layer_idx`` (0-indexed).

        Raises:
            IndexError: If ``layer_idx`` is outside ``[0, num_layers)``.
                Negative indices are rejected rather than wrapped.
        """
        if not 0 <= layer_idx < self.num_layers:
            raise IndexError(f"Layer index {layer_idx} out of range [0, {self.num_layers})")
        return self.caches[layer_idx]

    def __len__(self) -> int:
        """Return the number of layers."""
        return self.num_layers

    @property
    def mode(self) -> KVCacheMode | None:
        """Current cache mode, or ``None`` when cache operations are disabled."""
        return self._mode

    def set_mode(self, mode: KVCacheMode | str | None) -> None:
        """Set the cache mode shared by all layers.

        Args:
            mode: A :class:`KVCacheMode`, one of the strings ``"write"``,
                ``"read"``, ``"skip"`` (case-insensitive), or ``None`` to
                disable cache operations.

        Raises:
            ValueError: If ``mode`` is a string that names no mode.
            TypeError: If ``mode`` is neither a ``KVCacheMode``, a string, nor
                ``None``.
        """
        if mode is None or isinstance(mode, KVCacheMode):
            self._mode = mode
            return

        if isinstance(mode, str):
            try:
                self._mode = KVCacheMode(mode.lower())
            except ValueError:
                # The message already names the offending value; drop the
                # enum's own lookup error from the traceback.
                raise ValueError(f"Invalid mode: '{mode}', must be one of 'write', 'read', 'skip'") from None
            return

        # Anything else used to be stored unchecked, which later broke
        # __repr__ and silently behaved like "skip" in attention layers.
        raise TypeError(f"mode must be a KVCacheMode, a str or None, got {type(mode).__name__}")

    def clear(self) -> None:
        """Drop every layer's cached tensors and reset the mode to ``None``."""
        for layer_cache in self.caches:
            layer_cache.clear()
        self._mode = None

    @property
    def is_empty(self) -> bool:
        """``True`` when no layer holds cached tensors."""
        return all(layer_cache.is_empty for layer_cache in self.caches)

    def __repr__(self) -> str:
        mode_str = self._mode.value if self._mode else "None"
        return f"GlmImageKVCache(num_layers={self.num_layers}, mode={mode_str}, is_empty={self.is_empty})"
+ """ + + def __init__( + self, + dim: int, + num_heads: int, + head_dim: int, + out_bias: bool = True, + eps: float = 1e-5, + ): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = head_dim + self.inner_dim = num_heads * head_dim + + # QKV projection (fused for efficiency) + self.to_qkv = QKVParallelLinear( + hidden_size=dim, + head_size=head_dim, + total_num_heads=num_heads, + disable_tp=True, + bias=True, + ) + + # QK normalization (LayerNorm, not RMSNorm for GLM-Image) + self.norm_q = nn.LayerNorm(head_dim, elementwise_affine=False, eps=eps) + self.norm_k = nn.LayerNorm(head_dim, elementwise_affine=False, eps=eps) + + # Output projection + self.to_out = nn.Sequential( + nn.Linear(self.inner_dim, dim, bias=out_bias), + nn.Dropout(0.0), + ) + + # RoPE and attention + self.rope = RotaryEmbedding(is_neox_style=False) + self.attn = Attention( + num_heads=num_heads, + head_size=head_dim, + softmax_scale=1.0 / (head_dim**0.5), + causal=False, + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + attention_mask: torch.Tensor | None = None, + kv_cache: GlmImageLayerKVCache | None = None, + kv_cache_mode: KVCacheMode | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass for joint attention. 
+ + Args: + hidden_states: Image hidden states [B, img_seq_len, D] + encoder_hidden_states: Text hidden states [B, text_seq_len, D] + image_rotary_emb: Tuple of (cos, sin) for RoPE + attention_mask: Optional attention mask for text tokens + kv_cache: Optional layer KV cache for image editing + kv_cache_mode: Cache mode (WRITE, READ, SKIP) + + Returns: + Tuple of (image_hidden_states, text_hidden_states) + """ + dtype = encoder_hidden_states.dtype + batch_size, text_seq_length, _ = encoder_hidden_states.shape + + # Concatenate text and image: [text, image] + hidden_states_combined = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + # QKV projection + qkv, _ = self.to_qkv(hidden_states_combined) + query, key, value = qkv.chunk(3, dim=-1) + + # Reshape: [B, S, H*D] -> [B, S, H, D] + query = query.unflatten(-1, (self.num_heads, -1)) + key = key.unflatten(-1, (self.num_heads, -1)) + value = value.unflatten(-1, (self.num_heads, -1)) + + # QK normalization + query = self.norm_q(query).to(dtype=dtype) + key = self.norm_k(key).to(dtype=dtype) + + # Apply RoPE only to image tokens (not text tokens) + if image_rotary_emb is not None: + cos, sin = image_rotary_emb + cos = cos.to(query.dtype) + sin = sin.to(query.dtype) + # Only apply RoPE to image part (after text_seq_length) + query_img = query[:, text_seq_length:, :, :] + key_img = key[:, text_seq_length:, :, :] + query_img = self.rope(query_img, cos, sin) + key_img = self.rope(key_img, cos, sin) + query = torch.cat([query[:, :text_seq_length, :, :], query_img], dim=1) + key = torch.cat([key[:, :text_seq_length, :, :], key_img], dim=1) + + # Handle KV cache for image editing + if kv_cache is not None and kv_cache_mode is not None: + if kv_cache_mode == KVCacheMode.WRITE: + kv_cache.store(key, value) + elif kv_cache_mode == KVCacheMode.READ: + k_cached, v_cached = kv_cache.get() + if k_cached is not None: + key = torch.cat([k_cached, key], dim=1) + value = torch.cat([v_cached, value], dim=1) + # KVCacheMode.SKIP: 
do nothing + + # Attention computation + hidden_states_out = self.attn(query, key, value) + hidden_states_out = hidden_states_out.flatten(2, 3) + hidden_states_out = hidden_states_out.to(dtype) + + # Output projection + hidden_states_out = self.to_out(hidden_states_out) + + # Split back to text and image + encoder_hidden_states_out = hidden_states_out[:, :text_seq_length, :] + hidden_states_out = hidden_states_out[:, text_seq_length:, :] + + return hidden_states_out, encoder_hidden_states_out + + +class GlmImageTransformerBlock(nn.Module): + """Single transformer block for GLM-Image.""" + + def __init__( + self, + dim: int = 2560, + num_attention_heads: int = 64, + attention_head_dim: int = 40, + time_embed_dim: int = 512, + ) -> None: + super().__init__() + + # 1. Attention with AdaLN + self.norm1 = GlmImageAdaLayerNormZero(time_embed_dim, dim) + self.attn = GlmImageAttention( + dim=dim, + num_heads=num_attention_heads, + head_dim=attention_head_dim, + ) + + # 2. Feedforward + self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5) + self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5) + self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + attention_mask: torch.Tensor | None = None, + attention_kwargs: dict[str, Any] | None = None, + kv_cache: GlmImageLayerKVCache | None = None, + kv_cache_mode: KVCacheMode | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass for transformer block. 
+ + Args: + hidden_states: Image hidden states + encoder_hidden_states: Text hidden states + temb: Timestep embedding + image_rotary_emb: RoPE embeddings + attention_mask: Text attention mask + attention_kwargs: Additional attention arguments + kv_cache: Layer-specific KV cache for image editing + kv_cache_mode: Cache mode (WRITE, READ, SKIP) + + Returns: + Tuple of (image_hidden_states, text_hidden_states) + """ + # 1. Timestep conditioning via AdaLN + ( + norm_hidden_states, + gate_msa, + shift_mlp, + scale_mlp, + gate_mlp, + norm_encoder_hidden_states, + c_gate_msa, + c_shift_mlp, + c_scale_mlp, + c_gate_mlp, + ) = self.norm1(hidden_states, encoder_hidden_states, temb) + + # 2. Attention + attn_hidden_states, attn_encoder_hidden_states = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + attention_mask=attention_mask, + kv_cache=kv_cache, + kv_cache_mode=kv_cache_mode, + ) + hidden_states = hidden_states + attn_hidden_states * gate_msa.unsqueeze(1) + encoder_hidden_states = encoder_hidden_states + attn_encoder_hidden_states * c_gate_msa.unsqueeze(1) + + # 3. Feedforward + norm_hidden_states = self.norm2(hidden_states) * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1) + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) * ( + 1 + c_scale_mlp.unsqueeze(1) + ) + c_shift_mlp.unsqueeze(1) + + ff_output = self.ff(norm_hidden_states) + ff_output_context = self.ff(norm_encoder_hidden_states) + hidden_states = hidden_states + ff_output * gate_mlp.unsqueeze(1) + encoder_hidden_states = encoder_hidden_states + ff_output_context * c_gate_mlp.unsqueeze(1) + + return hidden_states, encoder_hidden_states + + +class GlmImageTransformer2DModel(CachedTransformer): + """ + GLM-Image Transformer model for 2D image generation. + + This is the vllm-omni optimized version of the GLM-Image DiT model. + + Args: + od_config: OmniDiffusionConfig containing model configuration. 
+ patch_size: Size of image patches. + in_channels: Number of input channels (latent channels). + num_layers: Number of transformer blocks. + attention_head_dim: Dimension of each attention head. + num_attention_heads: Number of attention heads. + out_channels: Number of output channels. + text_embed_dim: Dimension of text embeddings. + time_embed_dim: Dimension of timestep embeddings. + condition_dim: Dimension of conditioning embeddings. + prior_vq_quantizer_codebook_size: Size of prior VQ codebook. + """ + + def __init__( + self, + od_config: OmniDiffusionConfig, + patch_size: int = 2, + in_channels: int = 16, + num_layers: int = 30, + attention_head_dim: int = 40, + num_attention_heads: int = 64, + out_channels: int = 16, + text_embed_dim: int = 1472, + time_embed_dim: int = 512, + condition_dim: int = 256, + prior_vq_quantizer_codebook_size: int = 16384, + ): + super().__init__() + + # Get num_layers from config if available + model_config = od_config.tf_model_config + if model_config is not None and hasattr(model_config, "num_layers"): + num_layers = model_config.num_layers + + self.od_config = od_config + self.patch_size = patch_size + self.in_channels = in_channels + self.out_channels = out_channels + + # GlmImage uses 2 additional SDXL-like conditions - target_size, crop_coords + pooled_projection_dim = 2 * 2 * condition_dim + inner_dim = num_attention_heads * attention_head_dim + + # 1. RoPE + self.rope = GlmImageRotaryPosEmbed(attention_head_dim, patch_size, theta=10000.0) + + # 2. 
Patch & Text-timestep embedding + self.image_projector = GlmImageImageProjector(in_channels, inner_dim, patch_size) + self.glyph_projector = FeedForward(text_embed_dim, inner_dim, inner_dim=inner_dim, activation_fn="gelu") + self.prior_token_embedding = nn.Embedding(prior_vq_quantizer_codebook_size, inner_dim) + self.prior_projector = FeedForward(inner_dim, inner_dim, inner_dim=inner_dim, activation_fn="linear-silu") + + self.time_condition_embed = GlmImageCombinedTimestepSizeEmbeddings( + embedding_dim=time_embed_dim, + condition_dim=condition_dim, + pooled_projection_dim=pooled_projection_dim, + timesteps_dim=time_embed_dim, + ) + + # 3. Transformer blocks + self.transformer_blocks = nn.ModuleList( + [ + GlmImageTransformerBlock(inner_dim, num_attention_heads, attention_head_dim, time_embed_dim) + for _ in range(num_layers) + ] + ) + + # 4. Output projection + self.norm_out = GlmImageAdaLayerNormContinuous(inner_dim, time_embed_dim, elementwise_affine=False) + self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels, bias=True) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + prior_token_id: torch.Tensor, + prior_token_drop: torch.Tensor, + timestep: torch.LongTensor, + target_size: torch.Tensor, + crop_coords: torch.Tensor, + attention_kwargs: dict[str, Any] | None = None, + return_dict: bool = True, + attention_mask: torch.Tensor | None = None, + image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + kv_cache: GlmImageKVCache | None = None, + ) -> torch.Tensor | Transformer2DModelOutput: + """ + Forward pass of the GLM-Image Transformer. + + Args: + hidden_states: Input latent tensor of shape [B, C, H, W]. + encoder_hidden_states: Text embeddings of shape [B, S, D]. + prior_token_id: Prior VQ token IDs. + prior_token_drop: Mask for dropping prior tokens (CFG). + timestep: Diffusion timestep. + target_size: Target image size for conditioning. 
+ crop_coords: Crop coordinates for conditioning. + attention_kwargs: Additional attention arguments. + return_dict: Whether to return a dataclass. + attention_mask: Optional attention mask for text tokens. + image_rotary_emb: Pre-computed rotary embeddings. + kv_cache: Optional KV cache for image editing. When provided, + the cache's mode determines behavior: + - WRITE: Store KV from condition images + - READ: Use cached KV during generation + - SKIP: No caching (same as None) + + Returns: + Output tensor or Transformer2DModelOutput. + """ + batch_size, num_channels, height, width = hidden_states.shape + + # Get KV cache mode + kv_cache_mode = kv_cache.mode if kv_cache is not None else None + + # 1. RoPE + if image_rotary_emb is None: + image_rotary_emb = self.rope(hidden_states) + # Move to correct device + image_rotary_emb = ( + image_rotary_emb[0].to(hidden_states.device), + image_rotary_emb[1].to(hidden_states.device), + ) + + # 2. Patch & Timestep embeddings + p = self.patch_size + post_patch_height = height // p + post_patch_width = width // p + + hidden_states = self.image_projector(hidden_states) + encoder_hidden_states = self.glyph_projector(encoder_hidden_states) + + # Prior embedding with dropout + prior_embedding = self.prior_token_embedding(prior_token_id) + prior_embedding[prior_token_drop] *= 0.0 + prior_hidden_states = self.prior_projector(prior_embedding) + hidden_states = hidden_states + prior_hidden_states + + # Timestep conditioning + temb = self.time_condition_embed(timestep, target_size, crop_coords, hidden_states.dtype) + + # 3. 
Transformer blocks + for layer_idx, block in enumerate(self.transformer_blocks): + # Get layer-specific KV cache if available + layer_kv_cache = kv_cache[layer_idx] if kv_cache is not None else None + + hidden_states, encoder_hidden_states = block( + hidden_states, + encoder_hidden_states, + temb, + image_rotary_emb, + attention_mask, + attention_kwargs, + kv_cache=layer_kv_cache, + kv_cache_mode=kv_cache_mode, + ) + + # 4. Output norm & projection + hidden_states = self.norm_out(hidden_states, temb) + hidden_states = self.proj_out(hidden_states) + + # 5. Unpatchify: [B, H'*W', C*p*p] -> [B, C, H, W] + hidden_states = hidden_states.reshape(batch_size, post_patch_height, post_patch_width, -1, p, p) + output = hidden_states.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3) + + if not return_dict: + return (output,) + return Transformer2DModelOutput(sample=output) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + """ + Load weights from pretrained checkpoint. + + This method handles the mapping from diffusers weight names to vllm-omni weight names, + especially for fused QKV projections. 
+ """ + stacked_params_mapping = [ + # Fused QKV projection: to_q, to_k, to_v -> to_qkv + (".to_qkv", ".to_q", "q"), + (".to_qkv", ".to_k", "k"), + (".to_qkv", ".to_v", "v"), + ] + + params_dict = dict(self.named_parameters()) + + # Also include buffers (for any beta/eps parameters) + for name, buffer in self.named_buffers(): + params_dict[name] = buffer + + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + # Handle fused QKV projections + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + + # Map diffusers name to vllm-omni name + name = name.replace(weight_name, param_name) + + if name not in params_dict: + logger.warning(f"Skipping weight {name} - not found in model") + break + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight, shard_id) + break + else: + # Standard weight loading (not fused) + if name not in params_dict: + logger.warning(f"Skipping weight {name} - not found in model") + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + loaded_params.add(name) + + return loaded_params + + def create_kv_cache(self) -> GlmImageKVCache: + """ + Create a KV cache for image editing. + + Returns a new GlmImageKVCache instance sized for this model's + number of transformer layers. Use this for image editing workflows. + + Example: + kv_cache = transformer.create_kv_cache() + kv_cache.set_mode("write") + transformer(condition_image, kv_cache=kv_cache) + kv_cache.set_mode("read") + for t in timesteps: + transformer(noisy_target, kv_cache=kv_cache) + kv_cache.clear() + + Returns: + GlmImageKVCache instance with correct number of layers. 
+ """ + return GlmImageKVCache(num_layers=len(self.transformer_blocks)) + + @property + def num_layers(self) -> int: + """Return number of transformer layers.""" + return len(self.transformer_blocks) + + @property + def dtype(self) -> torch.dtype: + """Return dtype of model parameters.""" + return next(self.parameters()).dtype diff --git a/vllm_omni/diffusion/models/glm_image/pipeline_glm_image.py b/vllm_omni/diffusion/models/glm_image/pipeline_glm_image.py new file mode 100644 index 00000000000..599eb2cdd8e --- /dev/null +++ b/vllm_omni/diffusion/models/glm_image/pipeline_glm_image.py @@ -0,0 +1,965 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +GlmImagePipeline implementation for vLLM-Omni. + +This pipeline implements GLM-Image text-to-image generation with: +- AR stage: GlmImageForConditionalGeneration generates prior tokens +- DiT stage: GlmImageTransformer2DModel performs diffusion denoising +- VAE: AutoencoderKL decodes latents to images +""" + +from __future__ import annotations + +import inspect +import json +import logging +import os +import re +from collections.abc import Iterable +from math import sqrt + +import numpy as np +import PIL.Image +import torch +from diffusers.image_processor import VaeImageProcessor +from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL +from diffusers.schedulers.scheduling_flow_match_euler_discrete import ( + FlowMatchEulerDiscreteScheduler, +) +from diffusers.utils.torch_utils import randn_tensor +from torch import nn +from transformers import ( + ByT5Tokenizer, + GlmImageForConditionalGeneration, + GlmImageProcessor, + T5EncoderModel, +) + +from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.parallel_state import ( + get_cfg_group, + get_classifier_free_guidance_rank, + get_classifier_free_guidance_world_size, +) +from vllm_omni.diffusion.distributed.utils import 
get_local_device +from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.models.glm_image.glm_image_transformer import ( + GlmImageKVCache, + GlmImageTransformer2DModel, +) +from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.model_executor.model_loader.weight_utils import ( + download_weights_from_hf_specific, +) + +logger = logging.getLogger(__name__) + + +def get_glm_image_post_process_func(od_config: OmniDiffusionConfig): + """Get post-processing function for GLM-Image pipeline.""" + model_name = od_config.model + if os.path.exists(model_name): + model_path = model_name + else: + model_path = download_weights_from_hf_specific(model_name, None, ["*"]) + + vae_config_path = os.path.join(model_path, "vae/config.json") + with open(vae_config_path) as f: + vae_config = json.load(f) + block_out_channels = vae_config.get("block_out_channels", [128, 256, 512, 512]) + vae_scale_factor = 2 ** (len(block_out_channels) - 1) + + image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor) + + def post_process_func(images: torch.Tensor): + return image_processor.postprocess(images) + + return post_process_func + + +def calculate_shift( + image_seq_len: int, + base_seq_len: int = 256, + base_shift: float = 0.25, + max_shift: float = 0.75, +) -> float: + """Calculate timestep shift based on image sequence length.""" + m = (image_seq_len / base_seq_len) ** 0.5 + mu = m * max_shift + base_shift + return mu + + +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +) -> tuple[torch.Tensor, int]: + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps. + Handles custom timesteps and sigmas schedules. 
+ """ + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + accepts_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + + if timesteps is not None and sigmas is not None: + # Both provided - check if scheduler supports both + if not accepts_timesteps and not accepts_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep or sigma schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif timesteps is not None: + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + if not accepts_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigma schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + + return timesteps, num_inference_steps + + +def retrieve_latents( + encoder_output: torch.Tensor, + generator: torch.Generator | None = None, + sample_mode: str = "sample", +) -> torch.Tensor: + """Extract latents from VAE encoder output.""" + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class GlmImagePipeline(nn.Module): + """ + GLM-Image Pipeline for text-to-image and image-to-image generation. + + This pipeline integrates: + - AR model (GlmImageForConditionalGeneration): Generates prior image tokens + - Text encoder (T5EncoderModel): Encodes glyph/text embeddings + - DiT model (GlmImageTransformer2DModel): Diffusion transformer + - VAE (AutoencoderKL): Encodes/decodes images to/from latent space + + The pipeline flow: + 1. AR generates prior_token_ids from text prompt + 2. T5 encodes glyph text for text rendering + 3. DiT performs iterative denoising conditioned on prior tokens + 4. 
VAE decodes final latents to image + """ + + def __init__( + self, + *, + od_config: OmniDiffusionConfig, + prefix: str = "", + ): + super().__init__() + self.od_config = od_config + self.parallel_config = od_config.parallel_config + self.device = get_local_device() + + model = od_config.model + local_files_only = os.path.exists(model) + + if local_files_only: + model_path = model + else: + model_path = download_weights_from_hf_specific(model, od_config.revision, ["*"]) + + # Load scheduler + self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + model_path, subfolder="scheduler", local_files_only=True + ) + + # Load AR model (vision_language_encoder) + logger.info("Loading GlmImageForConditionalGeneration (AR model)...") + self.vision_language_encoder = GlmImageForConditionalGeneration.from_pretrained( + model_path, + subfolder="vision_language_encoder", + local_files_only=True, + torch_dtype=torch.bfloat16, + ).to(self.device) + self.vision_language_encoder.eval() + + # Load processor for AR model + self.processor = GlmImageProcessor.from_pretrained(model_path, subfolder="processor", local_files_only=True) + + # Load text encoder (T5 for glyph embeddings) + logger.info("Loading T5EncoderModel (glyph encoder)...") + self.text_encoder = T5EncoderModel.from_pretrained( + model_path, + subfolder="text_encoder", + local_files_only=True, + torch_dtype=torch.bfloat16, + ).to(self.device) + self.text_encoder.eval() + + # Load tokenizer for glyph encoding + self.tokenizer = ByT5Tokenizer.from_pretrained(model_path, subfolder="tokenizer", local_files_only=True) + + # Load VAE + logger.info("Loading AutoencoderKL (VAE)...") + self.vae = AutoencoderKL.from_pretrained( + model_path, subfolder="vae", local_files_only=True, torch_dtype=torch.bfloat16 + ).to(self.device) + self.vae.eval() + + # Load transformer (DiT) + logger.info("Loading GlmImageTransformer2DModel (DiT)...") + self.transformer = GlmImageTransformer2DModel(od_config=od_config) + + # Weight sources 
for DiT loading + self.weights_sources = [ + DiffusersPipelineLoader.ComponentSource( + model_or_path=od_config.model, + subfolder="transformer", + revision=od_config.revision, + prefix="transformer.", + fall_back_to_pt=True, + ) + ] + + # Configure scale factors + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.default_sample_size = 128 + + # Get transformer config for patch size + self._patch_size = getattr(self.transformer, "patch_size", 2) + + # ==================== Input Validation ==================== + + def check_inputs( + self, + prompt: str | list[str] | None, + height: int | None, + width: int | None, + prompt_embeds: torch.Tensor | None = None, + ) -> None: + """Validate input arguments before generation.""" + # Check dimension alignment + multiple_of = self.vae_scale_factor * self._patch_size + if height is not None and height % multiple_of != 0: + logger.warning( + f"`height` should be divisible by {multiple_of} but is {height}. " + "Dimensions will be adjusted accordingly." + ) + if width is not None and width % multiple_of != 0: + logger.warning( + f"`width` should be divisible by {multiple_of} but is {width}. Dimensions will be adjusted accordingly." + ) + + # Check prompt/prompt_embeds mutual exclusivity + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. " + "Please provide only one of the two." + ) + if prompt is None and prompt_embeds is None: + raise ValueError("Provide either `prompt` or `prompt_embeds`. 
Cannot leave both undefined.") + + # Check prompt type + if prompt is not None and not isinstance(prompt, (str, list)): + raise ValueError(f"`prompt` must be of type `str` or `list` but is {type(prompt)}") + + # ==================== AR Stage Methods ==================== + + @staticmethod + def _build_image_grid_thw( + token_h: int, + token_w: int, + prev_token_h: int, + prev_token_w: int, + existing_grid: torch.Tensor | None = None, + device: torch.device | None = None, + ) -> torch.Tensor: + """Build image grid tensor for AR model.""" + if existing_grid is None or existing_grid.numel() == 0: + return torch.tensor( + [ + [1, token_h, token_w], + [1, prev_token_h, prev_token_w], + ], + device=device, + ) + else: + return torch.cat( + [existing_grid.to(device), torch.tensor([[1, token_h, token_w]], device=device)], + dim=0, + ) + + @staticmethod + def _calculate_ar_generation_params( + token_h: int, token_w: int, prev_token_h: int, prev_token_w: int, is_text_to_image: bool + ) -> tuple[int, int]: + """Calculate AR generation parameters.""" + large_image_tokens = token_h * token_w + small_image_tokens = prev_token_h * prev_token_w + + if is_text_to_image: + max_new_tokens = small_image_tokens + large_image_tokens + 1 + large_image_start_offset = small_image_tokens + else: + max_new_tokens = large_image_tokens + 1 + large_image_start_offset = 0 + + return max_new_tokens, large_image_start_offset + + @staticmethod + def _extract_large_image_tokens( + outputs: torch.Tensor, input_length: int, large_image_start_offset: int, large_image_tokens: int + ) -> torch.Tensor: + """Extract large image tokens from AR output.""" + generated_tokens = outputs[0][input_length:] + large_image_start = large_image_start_offset + large_image_end = large_image_start + large_image_tokens + return generated_tokens[large_image_start:large_image_end] + + @staticmethod + def _upsample_token_ids(token_ids: torch.Tensor, token_h: int, token_w: int) -> torch.Tensor: + """Upsample token IDs by 2x 
@torch.inference_mode()
def generate_prior_tokens(
    self,
    prompt: str,
    height: int,
    width: int,
    image: list[PIL.Image.Image] | None = None,
    factor: int = 32,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """
    Generate prior tokens using the AR model.

    The requested size is floored to a multiple of ``factor``, the prompt is
    expanded with the token-grid shape, and the AR model samples new tokens
    (``do_sample=True``; no generator is passed, so sampling is not seeded
    here). The tokens of the full-size grid are sliced out of the generation
    and upsampled by ``_upsample_token_ids``.

    Args:
        prompt: Text prompt for generation
        height: Target image height
        width: Target image width
        image: Optional condition images for image-to-image; ``None`` or an
            empty list selects text-to-image mode
        factor: Pixels per AR token along each axis (default 32). The same
            value is used to floor the size and to build the shape prompt.

    Returns:
        Tuple of (prior_token_ids, prior_token_image_ids).
        ``prior_token_image_ids`` is ``None`` unless condition images were
        given and the processor produced an ``image_grid_thw`` entry.
    """
    device = self.vision_language_encoder.device
    height = (height // factor) * factor
    width = (width // factor) * factor
    is_text_to_image = image is None or len(image) == 0

    # Forward `factor` so the token grid matches the floored size above;
    # relying on the helper's default would disagree for any factor != 32.
    expanded_prompt, token_h, token_w, prev_h, prev_w = self._build_prompt_with_shape(
        prompt, height, width, is_text_to_image, factor=factor
    )

    # Build message content: condition images first, then the expanded prompt
    content = [{"type": "image", "image": img} for img in (image or [])]
    content.append({"type": "text", "text": expanded_prompt})
    messages = [{"role": "user", "content": content}]

    # Apply chat template
    inputs = self.processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )

    # Build image grid (the processor's grid is only reused in image-to-image mode)
    existing_grid = inputs.get("image_grid_thw")
    inputs["image_grid_thw"] = self._build_image_grid_thw(
        token_h,
        token_w,
        prev_h,
        prev_w,
        existing_grid=existing_grid if not is_text_to_image else None,
        device=device,
    )

    max_new_tokens, large_image_offset = self._calculate_ar_generation_params(
        token_h, token_w, prev_h, prev_w, is_text_to_image
    )
    large_image_tokens = token_h * token_w

    inputs = inputs.to(device)
    # Prompt length, used below to strip the prompt from the generated sequence
    input_length = inputs["input_ids"].shape[-1]

    # Process condition images if provided
    prior_token_image_ids = None
    if image is not None and existing_grid is not None:
        prior_token_image_embed = self.vision_language_encoder.get_image_features(
            inputs["pixel_values"], existing_grid
        )
        prior_token_image_embed = torch.cat(prior_token_image_embed, dim=0)
        prior_token_image_ids = self.vision_language_encoder.get_image_tokens(
            prior_token_image_embed, existing_grid
        )

    # Generate with AR model
    outputs = self.vision_language_encoder.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True,
    )

    # Extract and upsample tokens
    prior_token_ids_d32 = self._extract_large_image_tokens(
        outputs, input_length, large_image_offset, large_image_tokens
    )
    prior_token_ids = self._upsample_token_ids(prior_token_ids_d32, token_h, token_w)

    return prior_token_ids, prior_token_image_ids
def encode_prompt(
    self,
    prompt: str | list[str],
    do_classifier_free_guidance: bool = True,
    num_images_per_prompt: int = 1,
    prompt_embeds: torch.Tensor | None = None,
    device: torch.device | None = None,
    dtype: torch.dtype | None = None,
    max_sequence_length: int = 2048,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """Produce positive (and optionally negative) glyph embeddings.

    Args:
        prompt: Prompt text or list of prompts; only consulted when
            ``prompt_embeds`` is not supplied.
        do_classifier_free_guidance: When true, embeddings for empty prompts
            are computed as the negative branch.
        num_images_per_prompt: Number of copies of each embedding to emit.
        prompt_embeds: Pre-computed embeddings that bypass the glyph encoder.
        device: Target device; falls back to ``self.device``.
        dtype: Forwarded to ``_get_glyph_embeds``.
        max_sequence_length: Forwarded to ``_get_glyph_embeds``.

    Returns:
        ``(prompt_embeds, negative_prompt_embeds)``; the second item is
        ``None`` when classifier-free guidance is disabled.
    """
    target_device = device or self.device
    prompts = [prompt] if isinstance(prompt, str) else prompt

    if prompt_embeds is None:
        batch_size = len(prompts)
        prompt_embeds = self._get_glyph_embeds(prompts, max_sequence_length, target_device, dtype)
    else:
        batch_size = prompt_embeds.shape[0]

    def _tile(embeds: torch.Tensor) -> torch.Tensor:
        # Repeat along the sequence axis, then fold the copies into the batch axis.
        tokens = embeds.size(1)
        repeated = embeds.repeat(1, num_images_per_prompt, 1)
        return repeated.view(batch_size * num_images_per_prompt, tokens, -1)

    positive = _tile(prompt_embeds)
    if not do_classifier_free_guidance:
        return positive, None

    blank_prompts = [""] * batch_size
    negative = self._get_glyph_embeds(blank_prompts, max_sequence_length, target_device, dtype)
    return positive, _tile(negative)
def diffuse(
    self,
    latents: torch.Tensor,
    prior_token_id: torch.Tensor,
    prompt_embeds: torch.Tensor,
    negative_prompt_embeds: torch.Tensor | None,
    timesteps: torch.Tensor,
    target_size: torch.Tensor,
    crop_coords: torch.Tensor,
    guidance_scale: float,
    do_classifier_free_guidance: bool,
    kv_caches: GlmImageKVCache | None = None,
) -> torch.Tensor:
    """
    Run the denoising loop, optionally splitting CFG across ranks.

    Args:
        latents: Initial noise latents
        prior_token_id: Prior tokens generated by AR model
        prompt_embeds: Positive (glyph) prompt embeddings
        negative_prompt_embeds: Negative prompt embeddings; only used when
            classifier-free guidance is enabled
        timesteps: Denoising timesteps, iterated in order
        target_size: Target image size tensor, forwarded to the transformer
        crop_coords: Crop coordinates tensor, forwarded to the transformer
        guidance_scale: CFG scale
        do_classifier_free_guidance: Whether to apply CFG
        kv_caches: Optional KV cache for Image Edit mode

    Returns:
        Latents after the last scheduler step.
    """
    # Drop masks for the prior tokens: all-False (conditional), all-True (unconditional).
    keep_prior = torch.full_like(prior_token_id, False, dtype=torch.bool)
    drop_prior = torch.full_like(prior_token_id, True, dtype=torch.bool)

    model_dtype = self.transformer.dtype

    # CFG-parallel only applies when guidance is on and the CFG group has >1 rank.
    use_cfg_parallel = do_classifier_free_guidance and get_classifier_free_guidance_world_size() > 1

    def _predict(sample, step, embeds, drop_mask):
        """One transformer forward pass, returned as float32."""
        return self.transformer(
            hidden_states=sample,
            encoder_hidden_states=embeds,
            prior_token_id=prior_token_id,
            prior_token_drop=drop_mask,
            timestep=step,
            target_size=target_size,
            crop_coords=crop_coords,
            kv_caches=kv_caches,
            return_dict=False,
        )[0].float()

    def _guide(cond, uncond):
        """Standard classifier-free guidance combination."""
        return uncond + guidance_scale * (cond - uncond)

    for t in timesteps:
        model_input = latents.to(model_dtype)
        # The transformer receives the scheduler timestep shifted down by one.
        step = t.expand(latents.shape[0]) - 1

        if not use_cfg_parallel:
            # Sequential path: conditional pass, plus unconditional pass under CFG.
            noise_pred = _predict(model_input, step, prompt_embeds, keep_prior)
            if do_classifier_free_guidance:
                uncond_pred = _predict(model_input, step, negative_prompt_embeds, drop_prior)
                noise_pred = _guide(noise_pred, uncond_pred)
            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
            continue

        group = get_cfg_group()
        rank = get_classifier_free_guidance_rank()

        # Rank 0 computes the conditional branch, every other rank the unconditional one.
        if rank == 0:
            local_pred = _predict(model_input, step, prompt_embeds, keep_prior)
        else:
            local_pred = _predict(model_input, step, negative_prompt_embeds, drop_prior)

        gathered = group.all_gather(local_pred, separate_tensors=True)

        if rank == 0:
            # Only rank 0 combines the predictions and advances the scheduler.
            combined = _guide(gathered[0], gathered[1])
            latents = self.scheduler.step(combined, t, latents, return_dict=False)[0]

        # NOTE(review): the return value is ignored, so this relies on
        # `broadcast` updating `latents` in place on non-source ranks - confirm.
        group.broadcast(latents, src=0)

    return latents
def _preprocess_condition_images(
    self,
    images: list[PIL.Image.Image] | PIL.Image.Image | None,
) -> tuple[list[torch.Tensor] | None, int | None, int | None]:
    """
    Resize-align and preprocess the condition images used in Image Edit mode.

    Each image's height and width are floored to a multiple of
    ``vae_scale_factor * _patch_size`` before it is handed to
    ``self.image_processor.preprocess``.

    Args:
        images: A single image, a list of images, or ``None``

    Returns:
        Tuple of (preprocessed_images, height, width). All three are ``None``
        when ``images`` is ``None``; height and width are the aligned
        dimensions of the first image (``None`` for an empty list).
    """
    if images is None:
        return None, None, None

    image_list = images if isinstance(images, list) else [images]

    tensors = []
    first_hw = None
    for picture in image_list:
        if isinstance(picture, PIL.Image.Image):
            # PIL reports size as (width, height).
            raw_w, raw_h = picture.size
        else:
            # assumes a height-first array layout (H, W, ...) - TODO confirm
            raw_h, raw_w = picture.shape[:2]

        # Floor both sides to the nearest multiple of vae_scale_factor * patch_size.
        step = self.vae_scale_factor * self._patch_size
        aligned_h = raw_h - raw_h % step
        aligned_w = raw_w - raw_w % step

        tensors.append(self.image_processor.preprocess(picture, height=aligned_h, width=aligned_w))

        # The first image decides the default output dimensions.
        if first_hw is None:
            first_hw = (aligned_h, aligned_w)

    if first_hw is None:
        return tensors, None, None
    return tensors, first_hw[0], first_hw[1]
num_inference_steps = req.num_inference_steps or 50 + guidance_scale = req.guidance_scale or 1.5 + + # 0. Validate inputs + self.check_inputs(prompt=prompt, height=height, width=width, prompt_embeds=prompt_embeds) + + batch_size = 1 + do_classifier_free_guidance = guidance_scale > 1.0 + + # Set seed if provided + generator = None + if req.seed is not None: + generator = torch.Generator(device=self.device).manual_seed(req.seed) + + # 1. Generate prior tokens with AR model + logger.info("Generating prior tokens with AR model...") + prior_token_id, prior_token_image_ids = self.generate_prior_tokens( + prompt=prompt, + image=condition_images, + height=height, + width=width, + ) + + # 2. Encode prompt for glyph embeddings + logger.info("Encoding prompt...") + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + num_images_per_prompt=1, + prompt_embeds=prompt_embeds, + device=self.device, + dtype=self.transformer.dtype, + ) + + # 3. Prepare KV cache for Image Edit mode + kv_caches = None + if is_image_edit and prior_token_image_ids is not None: + logger.info("Preparing KV cache for Image Edit mode...") + kv_caches = self._prepare_condition_image_kv_cache( + condition_images=preprocessed_images, + prior_token_image_ids=prior_token_image_ids, + prompt_embeds=prompt_embeds, + generator=generator, + ) + # Switch to read mode for denoising + kv_caches.set_mode("read") + + # 4. Prepare latents + latent_channels = self.transformer.in_channels + latents = self.prepare_latents( + batch_size=batch_size, + num_channels_latents=latent_channels, + height=height, + width=width, + dtype=prompt_embeds.dtype, + device=self.device, + generator=generator, + ) + + # 5. 
Prepare timesteps + image_seq_len = ((height // self.vae_scale_factor) * (width // self.vae_scale_factor)) // (self._patch_size**2) + timesteps_array = np.linspace(self.scheduler.config.num_train_timesteps, 1.0, num_inference_steps + 1)[:-1] + timesteps_array = timesteps_array.astype(np.int64).astype(np.float32) + sigmas = timesteps_array / self.scheduler.config.num_train_timesteps + + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("base_shift", 0.25), + self.scheduler.config.get("max_shift", 0.75), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, self.device, timesteps_array.tolist(), sigmas.tolist(), mu=mu + ) + + # 6. Prepare conditioning tensors + target_size = torch.tensor([[height, width]], dtype=prompt_embeds.dtype, device=self.device) + crop_coords = torch.zeros((1, 2), dtype=prompt_embeds.dtype, device=self.device) + + # 7. Denoising loop with CFG-parallel support + logger.info(f"Starting denoising loop with {num_inference_steps} steps...") + latents = self.diffuse( + latents=latents, + prior_token_id=prior_token_id, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + timesteps=timesteps, + target_size=target_size, + crop_coords=crop_coords, + guidance_scale=guidance_scale, + do_classifier_free_guidance=do_classifier_free_guidance, + kv_caches=kv_caches, + ) + + # 8. VAE decode + logger.info("Decoding latents with VAE...") + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.latent_channels, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std) + .view(1, self.vae.config.latent_channels, 1, 1) + .to(latents.device, latents.dtype) + ) + latents = latents * latents_std + latents_mean + image = self.vae.decode(latents, return_dict=False, generator=generator)[0] + + # 9. 
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    """Load the DiT transformer weights from a stream of named tensors.

    Only entries whose name starts with ``"transformer."`` are forwarded, with
    that leading prefix removed; everything else is ignored. The stream is
    consumed lazily by ``self.transformer.load_weights``.

    Args:
        weights: Iterable of ``(name, tensor)`` pairs.

    Returns:
        Whatever ``self.transformer.load_weights`` returns for the filtered
        stream.
    """
    prefix = "transformer."
    # Strip only the leading prefix: `str.replace` would also delete later
    # occurrences (e.g. inside "...sub_transformer.proj") and corrupt the name.
    transformer_weights = (
        (name.removeprefix(prefix), weight) for name, weight in weights if name.startswith(prefix)
    )
    return self.transformer.load_weights(transformer_weights)