diff --git a/docs/user_guide/diffusion_acceleration.md b/docs/user_guide/diffusion_acceleration.md index 88ea4e5e823..11836bbd7c1 100644 --- a/docs/user_guide/diffusion_acceleration.md +++ b/docs/user_guide/diffusion_acceleration.md @@ -72,6 +72,7 @@ The following table shows which models are currently supported by each accelerat | **FLUX.2-klein** | `black-forest-labs/FLUX.2-klein-4B` | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | | **FLUX.2-dev** | `black-forest-labs/FLUX.2-dev` | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | | **FLUX.1-Kontext-dev** | `black-forest-labs/FLUX.1-Kontext-dev` | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | +| **GLM-Image** | `zai-org/GLM-Image` | ❌ | ❌ | ❌ | ❌ | ✅ | ✅ | ❌ | ### VideoGen diff --git a/vllm_omni/diffusion/models/glm_image/glm_image_transformer.py b/vllm_omni/diffusion/models/glm_image/glm_image_transformer.py index cf6ec8f03ab..8b129ce2a55 100644 --- a/vllm_omni/diffusion/models/glm_image/glm_image_transformer.py +++ b/vllm_omni/diffusion/models/glm_image/glm_image_transformer.py @@ -1,17 +1,22 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import math from collections.abc import Iterable from enum import Enum from typing import Any import torch import torch.nn as nn -from diffusers.models.attention import FeedForward +import torch.nn.functional as F from diffusers.models.modeling_outputs import Transformer2DModelOutput from diffusers.models.transformers.transformer_glm_image import GlmImageCombinedTimestepSizeEmbeddings from vllm.logger import init_logger -from vllm.model_executor.layers.linear import QKVParallelLinear +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm_omni.diffusion.attention.layer import Attention @@ -21,6 +26,72 @@ logger = init_logger(__name__) +def _positive_divisors(n: int) -> set[int]: + if n <= 0: + return set() + divs: set[int] = set() + for d in range(1, int(math.isqrt(n)) + 1): + if n % d == 0: + divs.add(d) + divs.add(n // d) + return divs + + +def validate_glm_image_tp_constraints( + *, + dim: int, + num_heads: int, + ffn_hidden_dim: int, + tensor_parallel_size: int, +) -> list[int]: + """Validate GLM-Image TP constraints without requiring a distributed context. + + Args: + dim: Model hidden dimension + num_heads: Number of attention heads + ffn_hidden_dim: FFN hidden dimension + tensor_parallel_size: TP size to validate against + + Returns: + List of supported TP candidates + + Raises: + ValueError: If constraints are violated + """ + tp_size = int(tensor_parallel_size) + if tp_size <= 0: + raise ValueError(f"tensor_parallel_size must be > 0, got {tp_size}") + + if dim % tp_size != 0: + supported = sorted(_positive_divisors(dim)) + raise ValueError( + f"GLM-Image requires dim % tensor_parallel_size == 0, " + f"but got dim={dim}, tp={tp_size}. " + f"Supported tp candidates by dim: {supported}" + ) + + if num_heads % tp_size != 0: + supported = sorted(_positive_divisors(num_heads)) + raise ValueError( + f"GLM-Image requires num_heads % tensor_parallel_size == 0, " + f"but got num_heads={num_heads}, tp={tp_size}. " + f"Supported tp candidates by num_heads: {supported}" + ) + + if ffn_hidden_dim % tp_size != 0: + supported = sorted(_positive_divisors(ffn_hidden_dim)) + raise ValueError( + f"GLM-Image requires ffn_hidden_dim % tensor_parallel_size == 0, " + f"but got ffn_hidden_dim={ffn_hidden_dim}, tp={tp_size}. " + f"Supported tp candidates by ffn_hidden_dim: {supported}" + ) + + supported_tp_candidates = sorted( + _positive_divisors(num_heads) & _positive_divisors(dim) & _positive_divisors(ffn_hidden_dim) + ) + return supported_tp_candidates + + class GlmImageImageProjector(nn.Module): """Projects latent image patches to transformer hidden dimension.""" @@ -330,17 +401,17 @@ def __init__( ): super().__init__() self.dim = dim - self.num_heads = num_heads + self.total_num_heads = num_heads self.head_dim = head_dim - self.inner_dim = num_heads * head_dim # QKV projection (fused for efficiency) self.to_qkv = QKVParallelLinear( hidden_size=dim, head_size=head_dim, total_num_heads=num_heads, - disable_tp=True, + total_num_kv_heads=num_heads, bias=True, + return_bias=False, ) # QK normalization (LayerNorm, not RMSNorm for GLM-Image) @@ -348,17 +419,26 @@ def __init__( self.norm_k = nn.LayerNorm(head_dim, elementwise_affine=False, eps=eps) # Output projection - self.to_out = nn.Sequential( - nn.Linear(self.inner_dim, dim, bias=out_bias), - nn.Dropout(0.0), + self.to_out = nn.ModuleList( + [ + RowParallelLinear( + dim, + dim, + bias=out_bias, + input_is_parallel=True, + return_bias=False, + ), + nn.Dropout(0.0), + ] ) # Attention self.attn = Attention( - num_heads=num_heads, + num_heads=self.to_qkv.num_heads, head_size=head_dim, softmax_scale=1.0 / (head_dim**0.5), causal=False, + num_kv_heads=self.to_qkv.num_kv_heads, ) def forward( @@ -390,13 +470,15 @@ def forward( hidden_states_combined = torch.cat([encoder_hidden_states, hidden_states], dim=1) # QKV projection - qkv, _ = self.to_qkv(hidden_states_combined) - query, key, value = qkv.chunk(3, dim=-1) + qkv = self.to_qkv(hidden_states_combined) + q_size = self.to_qkv.num_heads * self.head_dim + kv_size = self.to_qkv.num_kv_heads * self.head_dim + query, key, value = qkv.split([q_size, kv_size, kv_size], dim=-1) # Reshape: [B, S, H*D] -> [B, S, H, D] - query = query.unflatten(-1, (self.num_heads, -1)) - key = key.unflatten(-1, (self.num_heads, -1)) - value = value.unflatten(-1, (self.num_heads, -1)) + query = query.unflatten(-1, (self.to_qkv.num_heads, -1)) + key = key.unflatten(-1, (self.to_qkv.num_kv_heads, -1)) + value = value.unflatten(-1, (self.to_qkv.num_kv_heads, -1)) # QK normalization query = self.norm_q(query).to(dtype=dtype) @@ -431,7 +513,8 @@ def forward( hidden_states_out = hidden_states_out.to(dtype) # Output projection - hidden_states_out = self.to_out(hidden_states_out) + for module in self.to_out: + hidden_states_out = module(hidden_states_out) # Split back to text and image encoder_hidden_states_out = hidden_states_out[:, :text_seq_length, :] @@ -440,6 +523,100 @@ def forward( return hidden_states_out, encoder_hidden_states_out +class ColumnParallelGELU(nn.Module): + def __init__( + self, + dim_in: int, + dim_out: int, + *, + approximate: str = "none", + bias: bool = True, + ): + super().__init__() + self.proj = ColumnParallelLinear( + dim_in, + dim_out, + bias=bias, + gather_output=False, + return_bias=False, + ) + self.approximate = approximate + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + return F.gelu(x, approximate=self.approximate) + + +class ColumnParallelSiLU(nn.Module): + def __init__( + self, + dim_in: int, + dim_out: int, + *, + bias: bool = True, + ): + super().__init__() + self.proj = ColumnParallelLinear( + dim_in, + dim_out, + bias=bias, + gather_output=False, + return_bias=False, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + return F.silu(x) + + +class GlmImageFeedForward(nn.Module): + def __init__( + self, + dim: int, + dim_out: int | None = None, + mult: int = 4, + inner_dim: int | None = None, + bias: bool = True, + activation_fn: str = "gelu", + ): + super().__init__() + inner_dim = inner_dim or int(dim * mult) + dim_out = dim_out or dim + + if activation_fn == "linear-silu": + layers: list[nn.Module] = [ + ColumnParallelSiLU(dim, inner_dim, bias=bias), + nn.Identity(), + RowParallelLinear( + inner_dim, + dim_out, + bias=bias, + input_is_parallel=True, + return_bias=False, + ), + ] + else: + approximate = "tanh" if activation_fn == "gelu-approximate" else "none" + layers = [ + ColumnParallelGELU(dim, inner_dim, approximate=approximate, bias=bias), + nn.Identity(), + RowParallelLinear( + inner_dim, + dim_out, + bias=bias, + input_is_parallel=True, + return_bias=False, + ), + ] + + self.net = nn.ModuleList(layers) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states + + class GlmImageTransformerBlock(nn.Module): """Single transformer block for GLM-Image.""" @@ -449,6 +626,7 @@ def __init__( num_attention_heads: int = 64, attention_head_dim: int = 40, time_embed_dim: int = 512, + ffn_hidden_dim: int | None = None, ) -> None: super().__init__() @@ -463,7 +641,7 @@ def __init__( # 2. Feedforward self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5) self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5) - self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + self.ff = GlmImageFeedForward(dim=dim, dim_out=dim, inner_dim=ffn_hidden_dim, activation_fn="gelu-approximate") def forward( self, @@ -564,26 +742,50 @@ def __init__( # Get num_layers from config if available model_config = od_config.tf_model_config - if model_config is not None and hasattr(model_config, "num_layers"): - num_layers = model_config.num_layers + num_layers = getattr(model_config, "num_layers", 28) if model_config is not None else 28 self.od_config = od_config self.patch_size = patch_size self.in_channels = in_channels self.out_channels = out_channels + self.parallel_config = od_config.parallel_config # GlmImage uses 2 additional SDXL-like conditions - target_size, crop_coords pooled_projection_dim = 2 * 2 * condition_dim inner_dim = num_attention_heads * attention_head_dim + tp_size = self.parallel_config.tensor_parallel_size + ffn_hidden_dim = inner_dim * 4 + + supported_tp_candidates = validate_glm_image_tp_constraints( + dim=inner_dim, + num_heads=num_attention_heads, + ffn_hidden_dim=ffn_hidden_dim, + tensor_parallel_size=tp_size, + ) + + logger.info_once( + "GLM-Image init: dim=%d num_heads=%d head_dim=%d ffn_hidden_dim=%d tp=%d (supported_tp=%s)", + inner_dim, + num_attention_heads, + attention_head_dim, + ffn_hidden_dim, + tp_size, + tuple(supported_tp_candidates), + ) + # 1. RoPE self.rope = GlmImageRotaryPosEmbed(attention_head_dim, patch_size, theta=10000.0) # 2. Patch & Text-timestep embedding self.image_projector = GlmImageImageProjector(in_channels, inner_dim, patch_size) - self.glyph_projector = FeedForward(text_embed_dim, inner_dim, inner_dim=inner_dim, activation_fn="gelu") + self.glyph_projector = GlmImageFeedForward( + dim=text_embed_dim, dim_out=inner_dim, inner_dim=inner_dim, activation_fn="gelu" + ) self.prior_token_embedding = nn.Embedding(prior_vq_quantizer_codebook_size, inner_dim) - self.prior_projector = FeedForward(inner_dim, inner_dim, inner_dim=inner_dim, activation_fn="linear-silu") + self.prior_projector = GlmImageFeedForward( + dim=inner_dim, dim_out=inner_dim, inner_dim=inner_dim, activation_fn="linear-silu" + ) self.time_condition_embed = GlmImageCombinedTimestepSizeEmbeddings( embedding_dim=time_embed_dim, @@ -595,7 +797,13 @@ def __init__( # 3. Transformer blocks self.transformer_blocks = nn.ModuleList( [ - GlmImageTransformerBlock(inner_dim, num_attention_heads, attention_head_dim, time_embed_dim) + GlmImageTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + time_embed_dim, + ffn_hidden_dim=ffn_hidden_dim, + ) for _ in range(num_layers) ] )