diff --git a/tests/diffusion/models/glm_image/test_glm_image_sp.py b/tests/diffusion/models/glm_image/test_glm_image_sp.py new file mode 100644 index 0000000000..1b1c8d7a75 --- /dev/null +++ b/tests/diffusion/models/glm_image/test_glm_image_sp.py @@ -0,0 +1,134 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for GLM-Image Sequence Parallelism support.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from vllm_omni.diffusion.data import DiffusionParallelConfig + + +@pytest.fixture(scope="function", autouse=True) +def setup_sp_groups(): + """Set up SP and TP groups for each test function.""" + with patch("vllm_omni.diffusion.distributed.parallel_state.get_sp_group") as mock_get_sp_group: + with patch("vllm.model_executor.layers.linear.get_tensor_model_parallel_world_size", return_value=1): + with patch("vllm.distributed.parallel_state.get_tp_group") as mock_get_tp_group: + mock_sp_group = MagicMock() + mock_sp_group.world_size = 4 + mock_get_sp_group.return_value = mock_sp_group + + mock_tp_group = MagicMock() + mock_tp_group.world_size = 1 + mock_get_tp_group.return_value = mock_tp_group + yield + + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +def test_glm_image_sp_plan_defined(): + """Test that _sp_plan is properly defined on GlmImageTransformer2DModel.""" + from vllm_omni.diffusion.models.glm_image.glm_image_transformer import ( + GlmImageTransformer2DModel, + ) + + assert hasattr(GlmImageTransformer2DModel, "_sp_plan") + plan = GlmImageTransformer2DModel._sp_plan + assert plan is not None + + # Verify plan structure + assert "prepare" in plan + assert "proj_out" in plan + + +def test_glm_image_sp_plan_valid(): + """Validate _sp_plan structure.""" + from vllm_omni.diffusion.distributed.sp_plan import validate_sp_plan + from vllm_omni.diffusion.models.glm_image.glm_image_transformer import ( + GlmImageTransformer2DModel, + ) + + plan = GlmImageTransformer2DModel._sp_plan + validate_sp_plan(plan) + + +def test_glm_image_prepare_module_exists(): + """Test that GlmImagePrepare module exists.""" + from vllm_omni.diffusion.models.glm_image.glm_image_transformer import ( + GlmImagePrepare, + ) + + assert GlmImagePrepare is not None + + +def test_glm_image_attention_accepts_parallel_config(): + """Test that GlmImageAttention accepts parallel_config parameter.""" + from vllm_omni.diffusion.models.glm_image.glm_image_transformer import ( + GlmImageAttention, + ) + + parallel_config = DiffusionParallelConfig( + ulysses_degree=2, + ring_degree=2, + tensor_parallel_size=1, + sequence_parallel_size=4, + ) + + attn = GlmImageAttention( + dim=2560, + num_heads=64, + head_dim=40, + parallel_config=parallel_config, + ) + + assert attn.parallel_config is not None + assert attn.parallel_config.sequence_parallel_size == 4 + + +def test_glm_image_transformer_block_accepts_parallel_config(): + """Test that GlmImageTransformerBlock accepts parallel_config parameter.""" + from vllm_omni.diffusion.models.glm_image.glm_image_transformer import ( + GlmImageTransformerBlock, + ) + + parallel_config = DiffusionParallelConfig( + ulysses_degree=2, + ring_degree=2, + tensor_parallel_size=1, + sequence_parallel_size=4, + ) + + block = GlmImageTransformerBlock( + dim=2560, + num_attention_heads=64, + attention_head_dim=40, + time_embed_dim=512, + parallel_config=parallel_config, + ) + + assert block.attn1.parallel_config is not None + assert block.attn1.parallel_config.sequence_parallel_size == 4 + + +def test_glm_image_has_sp_support(): + """Test that GLM-Image has SP support implemented.""" + from vllm_omni.diffusion.models.glm_image.glm_image_transformer import ( + GlmImageTransformer2DModel, + ) + + # Check that the model has parallel_config support + assert hasattr(GlmImageTransformer2DModel, "__init__") + + # Verify the model can be instantiated with SP config + + # This test just verifies the structure exists + # Actual SP testing requires multi-GPU setup + + +@pytest.mark.cuda +@pytest.mark.sp +def test_glm_image_sp_inference(): + """Test SP inference (requires multi-GPU setup).""" + pytest.skip("Requires multi-GPU SP setup") diff --git a/vllm_omni/diffusion/attention/parallel/ulysses.py b/vllm_omni/diffusion/attention/parallel/ulysses.py index 5d860b3350..326b5d4567 100644 --- a/vllm_omni/diffusion/attention/parallel/ulysses.py +++ b/vllm_omni/diffusion/attention/parallel/ulysses.py @@ -414,10 +414,6 @@ def pre_attention( def post_attention(self, attn_output: torch.Tensor, ctx: ParallelAttentionContext | None) -> torch.Tensor: assert isinstance(ctx, _UlyssesCtx), f"Unexpected ctx type: {type(ctx)!r}" - # If we have joint tensors (Text), they were Head-Sliced. - # The main sequence (Image) was Sequence-Sliced. - # attn_output contains [Joint_Sliced | Image_Sliced] (if strategy='front'). - if ctx.joint_len > 0: joint_len = ctx.joint_len diff --git a/vllm_omni/diffusion/models/glm_image/glm_image_transformer.py b/vllm_omni/diffusion/models/glm_image/glm_image_transformer.py index 490e0198b9..7ff42a5f00 100644 --- a/vllm_omni/diffusion/models/glm_image/glm_image_transformer.py +++ b/vllm_omni/diffusion/models/glm_image/glm_image_transformer.py @@ -19,10 +19,16 @@ ) from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata from vllm_omni.diffusion.attention.layer import Attention from vllm_omni.diffusion.cache.base import CachedTransformer -from vllm_omni.diffusion.data import OmniDiffusionConfig +from vllm_omni.diffusion.data import DiffusionParallelConfig, OmniDiffusionConfig from vllm_omni.diffusion.distributed.hsdp_utils import is_transformer_block_module +from vllm_omni.diffusion.distributed.sp_plan import ( + SequenceParallelInput, + SequenceParallelOutput, +) +from vllm_omni.diffusion.forward_context import get_forward_context logger = init_logger(__name__) @@ -108,8 +114,8 @@ def __init__( def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: batch_size, channel, height, width = hidden_states.shape - post_patch_height = height // self.patch_size - post_patch_width = width // self.patch_size + post_patch_height = torch.tensor(height // self.patch_size, device=hidden_states.device, dtype=torch.int64) + post_patch_width = torch.tensor(width // self.patch_size, device=hidden_states.device, dtype=torch.int64) # Reshape: [B, C, H, W] -> [B, H', W', C*p*p] -> [B, H'*W', C*p*p] hidden_states = hidden_states.reshape( @@ -159,6 +165,65 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens return (freqs.cos(), freqs.sin()) +class GlmImagePrepare(nn.Module): + """Prepare module for GLM-Image that handles patch embedding and RoPE computation. + + This module encapsulates the input processing pipeline to create a module boundary + where _sp_plan can shard outputs via split_output=True. + + Similar to Qwen-Image's ImageRopePrepare, this ensures hidden_states and RoPE + embeddings are sharded together to maintain dimension alignment. + """ + + def __init__( + self, + image_projector: nn.Module, + rope: GlmImageRotaryPosEmbed, + patch_size: int, + ): + super().__init__() + self.image_projector = image_projector + self.rope = rope + self.patch_size = patch_size + + def forward( + self, + hidden_states: torch.Tensor, + prior_hidden_states: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Process hidden_states and compute RoPE embeddings. + + Args: + hidden_states: Input latent tensor [B, C, H, W] + prior_hidden_states: Optional prior embedding to add + + Returns: + hidden_states: Patched hidden states [B, seq_len, D] + rope_cos: RoPE cos embeddings [seq_len, dim] + rope_sin: RoPE sin embeddings [seq_len, dim] + post_patch_height: Scalar tensor for height after patching + post_patch_width: Scalar tensor for width after patching + """ + batch_size, num_channels, height, width = hidden_states.shape + + post_patch_height = torch.tensor(height // self.patch_size, device=hidden_states.device, dtype=torch.int64) + post_patch_width = torch.tensor(width // self.patch_size, device=hidden_states.device, dtype=torch.int64) + + # Compute RoPE (uses original 4D hidden_states shape) + image_rotary_emb = self.rope(hidden_states) + rope_cos = image_rotary_emb[0].to(hidden_states.device) + rope_sin = image_rotary_emb[1].to(hidden_states.device) + + # Patch embedding: [B, C, H, W] -> [B, seq_len, D] + hidden_states = self.image_projector(hidden_states) + + # Add prior embedding if provided + if prior_hidden_states is not None: + hidden_states = hidden_states + prior_hidden_states + + return hidden_states, rope_cos, rope_sin, post_patch_height, post_patch_width + + class GlmImageAdaLayerNormZero(nn.Module): """Adaptive LayerNorm with zero initialization for both image and text streams.""" @@ -397,6 +462,7 @@ def __init__( dim: int, num_heads: int, head_dim: int, + parallel_config: DiffusionParallelConfig | None = None, out_bias: bool = True, eps: float = 1e-5, ): @@ -404,6 +470,7 @@ def __init__( self.dim = dim self.total_num_heads = num_heads self.head_dim = head_dim + self.parallel_config = parallel_config # QKV projection (fused for efficiency) self.to_qkv = QKVParallelLinear( @@ -450,16 +517,19 @@ def forward( attention_mask: torch.Tensor | None = None, kv_cache: GlmImageLayerKVCache | None = None, kv_cache_mode: KVCacheMode | None = None, + hidden_states_mask: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ Forward pass for joint attention. Args: - hidden_states: Image hidden states [B, img_seq_len, D] - encoder_hidden_states: Text hidden states [B, text_seq_len, D] - image_rotary_emb: Tuple of (cos, sin) for RoPE + hidden_states: Image hidden states [B, img_seq_len, D] (sharded in SP mode) + encoder_hidden_states: Text hidden states [B, text_seq_len, D] (full in SP mode) + image_rotary_emb: Tuple of (cos, sin) for RoPE (sharded in SP mode) + attention_mask: Optional attention mask kv_cache: Optional layer KV cache for image editing kv_cache_mode: Cache mode (WRITE, READ, SKIP) + hidden_states_mask: Mask for SP padding (True=valid, False=padding) Returns: Tuple of (image_hidden_states, text_hidden_states) @@ -467,6 +537,13 @@ def forward( dtype = encoder_hidden_states.dtype batch_size, text_seq_length, _ = encoder_hidden_states.shape + # Check if SP is enabled + sp_size = self.parallel_config.sequence_parallel_size if self.parallel_config else None + use_sp = sp_size is not None and sp_size > 1 + if use_sp: + forward_ctx = get_forward_context() + use_sp = not forward_ctx.split_text_embed_in_sp + # Concatenate text and image: [text, image] hidden_states_combined = torch.cat([encoder_hidden_states, hidden_states], dim=1) @@ -485,41 +562,88 @@ def forward( query = self.norm_q(query).to(dtype=dtype) key = self.norm_k(key).to(dtype=dtype) - # Apply RoPE only to image tokens (not text tokens) - if image_rotary_emb is not None: - # Only apply RoPE to image part (after text_seq_length) - query_img = query[:, text_seq_length:, :, :] - key_img = key[:, text_seq_length:, :, :] - from diffusers.models.embeddings import apply_rotary_emb - - query_img = apply_rotary_emb(query_img, image_rotary_emb, sequence_dim=1, use_real_unbind_dim=-2) - key_img = apply_rotary_emb(key_img, image_rotary_emb, sequence_dim=1, use_real_unbind_dim=-2) - query = torch.cat([query[:, :text_seq_length, :, :], query_img], dim=1) - key = torch.cat([key[:, :text_seq_length, :, :], key_img], dim=1) - - # Handle KV cache for image editing - if kv_cache is not None and kv_cache_mode is not None: - if kv_cache_mode == KVCacheMode.WRITE: - kv_cache.store(key, value) - elif kv_cache_mode == KVCacheMode.READ: - k_cached, v_cached = kv_cache.get() - if k_cached is not None: - key = torch.cat([k_cached, key], dim=1) - value = torch.cat([v_cached, value], dim=1) - # KVCacheMode.SKIP: do nothing - - # Attention computation - hidden_states_out = self.attn(query, key, value) - hidden_states_out = hidden_states_out.flatten(2, 3) - hidden_states_out = hidden_states_out.to(dtype) + if use_sp: + # SP mode: use joint attention mechanism + # Split Q/K/V into text and image parts + text_query = query[:, :text_seq_length, :, :] + text_key = key[:, :text_seq_length, :, :] + text_value = value[:, :text_seq_length, :, :] + img_query = query[:, text_seq_length:, :, :] + img_key = key[:, text_seq_length:, :, :] + img_value = value[:, text_seq_length:, :, :] + + # Apply RoPE only to image part + if image_rotary_emb is not None: + from diffusers.models.embeddings import apply_rotary_emb + + img_query = apply_rotary_emb(img_query, image_rotary_emb, sequence_dim=1, use_real_unbind_dim=-2) + img_key = apply_rotary_emb(img_key, image_rotary_emb, sequence_dim=1, use_real_unbind_dim=-2) + + # Create attention metadata for joint attention + attn_metadata = AttentionMetadata( + joint_query=text_query, + joint_key=text_key, + joint_value=text_value, + joint_strategy="front", + ) - # Output projection - for module in self.to_out: - hidden_states_out = module(hidden_states_out) + # Add padding mask for SP if available + if hidden_states_mask is not None: + attn_metadata.attn_mask = hidden_states_mask + + # Attention computation with joint text/image + # Note: Ulysses post_attention returns [text, image] concatenated + joint_hidden_states_out = self.attn(img_query, img_key, img_value, attn_metadata) + + # Project combined [text, image] outputs, then split. + # This keeps SP numerically aligned with the non-SP path. + joint_hidden_states_out = joint_hidden_states_out.flatten(2, 3).to(dtype) + for module in self.to_out: + joint_hidden_states_out = module(joint_hidden_states_out) - # Split back to text and image - encoder_hidden_states_out = hidden_states_out[:, :text_seq_length, :] - hidden_states_out = hidden_states_out[:, text_seq_length:, :] + encoder_hidden_states_out = joint_hidden_states_out[:, :text_seq_length, :] + hidden_states_out = joint_hidden_states_out[:, text_seq_length:, :] + else: + # Non-SP mode: original logic + # Apply RoPE only to image tokens (not text tokens) + if image_rotary_emb is not None: + query_img = query[:, text_seq_length:, :, :] + key_img = key[:, text_seq_length:, :, :] + from diffusers.models.embeddings import apply_rotary_emb + + query_img = apply_rotary_emb(query_img, image_rotary_emb, sequence_dim=1, use_real_unbind_dim=-2) + key_img = apply_rotary_emb(key_img, image_rotary_emb, sequence_dim=1, use_real_unbind_dim=-2) + query = torch.cat([query[:, :text_seq_length, :, :], query_img], dim=1) + key = torch.cat([key[:, :text_seq_length, :, :], key_img], dim=1) + + # Handle KV cache for image editing + if kv_cache is not None and kv_cache_mode is not None: + if kv_cache_mode == KVCacheMode.WRITE: + kv_cache.store(key, value) + elif kv_cache_mode == KVCacheMode.READ: + k_cached, v_cached = kv_cache.get() + if k_cached is not None: + key = torch.cat([k_cached, key], dim=1) + value = torch.cat([v_cached, value], dim=1) + + # Attention computation + attn_metadata = None + if attention_mask is not None: + if attention_mask.dim() == 3: + attention_mask = attention_mask.unsqueeze(1) + attn_metadata = AttentionMetadata(attn_mask=attention_mask) + + hidden_states_out = self.attn(query, key, value, attn_metadata) + hidden_states_out = hidden_states_out.flatten(2, 3) + hidden_states_out = hidden_states_out.to(dtype) + + # Output projection + for module in self.to_out: + hidden_states_out = module(hidden_states_out) + + # Split back to text and image + encoder_hidden_states_out = hidden_states_out[:, :text_seq_length, :] + hidden_states_out = hidden_states_out[:, text_seq_length:, :] return hidden_states_out, encoder_hidden_states_out @@ -628,6 +752,7 @@ def __init__( attention_head_dim: int = 40, time_embed_dim: int = 512, ffn_hidden_dim: int | None = None, + parallel_config: DiffusionParallelConfig | None = None, ) -> None: super().__init__() @@ -637,6 +762,7 @@ def __init__( dim=dim, num_heads=num_attention_heads, head_dim=attention_head_dim, + parallel_config=parallel_config, ) # 2. Feedforward @@ -654,6 +780,7 @@ def forward( attention_kwargs: dict[str, Any] | None = None, kv_cache: GlmImageLayerKVCache | None = None, kv_cache_mode: KVCacheMode | None = None, + hidden_states_mask: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ Forward pass for transformer block. @@ -667,6 +794,7 @@ def forward( attention_kwargs: Additional attention arguments kv_cache: Layer-specific KV cache for image editing kv_cache_mode: Cache mode (WRITE, READ, SKIP) + hidden_states_mask: Mask for SP padding (True=valid, False=padding) Returns: Tuple of (image_hidden_states, text_hidden_states) @@ -693,6 +821,7 @@ def forward( attention_mask=attention_mask, kv_cache=kv_cache, kv_cache_mode=kv_cache_mode, + hidden_states_mask=hidden_states_mask, ) hidden_states = hidden_states + attn_hidden_states * gate_msa.unsqueeze(1) encoder_hidden_states = encoder_hidden_states + attn_encoder_hidden_states * c_gate_msa.unsqueeze(1) @@ -724,6 +853,26 @@ class GlmImageTransformer2DModel(CachedTransformer): """ _repeated_blocks = ["GlmImageTransformerBlock"] + # SP plan using GlmImagePrepare module for sharding hidden_states and RoPE together. + # Similar to Qwen-Image's ImageRopePrepare, this creates a module boundary where + # _sp_plan can shard outputs via split_output=True. + # + # Key insight: hidden_states and RoPE embeddings MUST be sharded together + # to maintain dimension alignment for RoPE computation in attention layers. + _sp_plan = { + # Shard GlmImagePrepare outputs (hidden_states and RoPE must be sharded together) + "prepare": { + # hidden_states: [B, seq_len, D] - shard along sequence dimension + 0: SequenceParallelInput(split_dim=1, expected_dims=3, split_output=True, auto_pad=True), + # RoPE cos: [seq_len, dim] - shard along sequence dimension + 1: SequenceParallelInput(split_dim=0, expected_dims=2, split_output=True, auto_pad=True), + # RoPE sin: [seq_len, dim] - shard along sequence dimension + 2: SequenceParallelInput(split_dim=0, expected_dims=2, split_output=True, auto_pad=True), + # post_patch_height and post_patch_width are scalars, not sharded + }, + # Gather output at proj_out + "proj_out": SequenceParallelOutput(gather_dim=1, expected_dims=3), + } _hsdp_shard_conditions = [is_transformer_block_module] @@ -790,6 +939,9 @@ def __init__( dim=inner_dim, dim_out=inner_dim, inner_dim=inner_dim, activation_fn="linear-silu" ) + # Prepare module for SP (encapsulates patch embedding and RoPE for _sp_plan) + self.prepare = GlmImagePrepare(self.image_projector, self.rope, patch_size) + self.time_condition_embed = GlmImageCombinedTimestepSizeEmbeddings( embedding_dim=time_embed_dim, condition_dim=condition_dim, @@ -806,6 +958,7 @@ def __init__( attention_head_dim, time_embed_dim, ffn_hidden_dim=ffn_hidden_dim, + parallel_config=self.parallel_config, ) for _ in range(num_layers) ] @@ -859,33 +1012,51 @@ def forward( # Get KV cache mode kv_cache_mode = kv_cache.mode if kv_cache is not None else None - # 1. RoPE - if image_rotary_emb is None: - image_rotary_emb = self.rope(hidden_states) - # Move to correct device - image_rotary_emb = ( - image_rotary_emb[0].to(hidden_states.device), - image_rotary_emb[1].to(hidden_states.device), - ) - - # 2. Patch & Timestep embeddings - p = self.patch_size - post_patch_height = height // p - post_patch_width = width // p + # Set SP context if enabled + sp_size = self.parallel_config.sequence_parallel_size + if sp_size is not None and sp_size > 1: + get_forward_context().split_text_embed_in_sp = False - hidden_states = self.image_projector(hidden_states) + # Text embedding projection encoder_hidden_states = self.glyph_projector(encoder_hidden_states) # Prior embedding with dropout prior_embedding = self.prior_token_embedding(prior_token_id) prior_embedding[prior_token_drop] *= 0.0 prior_hidden_states = self.prior_projector(prior_embedding) - hidden_states = hidden_states + prior_hidden_states + + # 1. Prepare hidden_states and RoPE via GlmImagePrepare module + # _sp_plan will shard hidden_states and RoPE together via split_output=True + hidden_states, rope_cos, rope_sin, post_patch_height_t, post_patch_width_t = self.prepare( + hidden_states, prior_hidden_states + ) + image_rotary_emb = (rope_cos, rope_sin) + post_patch_height = int(post_patch_height_t.item()) + post_patch_width = int(post_patch_width_t.item()) # Timestep conditioning temb = self.time_condition_embed(timestep, target_size, crop_coords, hidden_states.dtype) - # 3. Transformer blocks + # Create padding mask for SP if needed (after _sp_plan hooks have run) + hidden_states_mask = None + if sp_size is not None and sp_size > 1: + from vllm_omni.diffusion.forward_context import is_forward_context_available + + if is_forward_context_available(): + ctx = get_forward_context() + if ctx.sp_original_seq_len is not None and ctx.sp_padding_size > 0: + img_padded_seq_len = ctx.sp_original_seq_len + ctx.sp_padding_size + hidden_states_mask = torch.ones( + batch_size, + img_padded_seq_len, + dtype=torch.bool, + device=hidden_states.device, + ) + hidden_states_mask[:, ctx.sp_original_seq_len :] = False + if hidden_states_mask.all(): + hidden_states_mask = None + + # 2. Transformer blocks for layer_idx, block in enumerate(self.transformer_blocks): # Get layer-specific KV cache if available layer_kv_cache = kv_cache[layer_idx] if kv_cache is not None else None @@ -899,13 +1070,16 @@ def forward( attention_kwargs, kv_cache=layer_kv_cache, kv_cache_mode=kv_cache_mode, + hidden_states_mask=hidden_states_mask, ) - # 4. Output norm & projection + # 3. Output norm & projection + # _sp_plan will gather hidden_states via proj_out hook hidden_states = self.norm_out(hidden_states, temb) hidden_states = self.proj_out(hidden_states) - # 5. Unpatchify: [B, H'*W', C*p*p] -> [B, C, H, W] + # 4. Unpatchify: [B, H'*W', C*p*p] -> [B, C, H, W] + p = self.patch_size hidden_states = hidden_states.reshape(batch_size, post_patch_height, post_patch_width, -1, p, p) output = hidden_states.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3) diff --git a/vllm_omni/diffusion/models/glm_image/pipeline_glm_image.py b/vllm_omni/diffusion/models/glm_image/pipeline_glm_image.py index 375f7e7b80..0386364998 100644 --- a/vllm_omni/diffusion/models/glm_image/pipeline_glm_image.py +++ b/vllm_omni/diffusion/models/glm_image/pipeline_glm_image.py @@ -712,6 +712,14 @@ def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput: if img is not None: preprocessed_images = [img] + # Priority: prompt dict (from ar2diffusion) > sampling_params + # ar2diffusion returns adjusted height/width that matches prior_token_ids + if not isinstance(first_prompt, str): + ar_height = first_prompt.get("height") + ar_width = first_prompt.get("width") + else: + ar_height = ar_width = None + img_height = req.sampling_params.height img_width = req.sampling_params.width @@ -719,12 +727,19 @@ def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput: # Treat that as t2i warmup to avoid requiring i2i-only KV-cache inputs. is_image_edit = (preprocessed_images is not None) and (not is_dummy_warmup) - # Use image dimensions as default if available - height = req.sampling_params.height or img_height or self.default_sample_size * self.vae_scale_factor - width = req.sampling_params.width or img_width or self.default_sample_size * self.vae_scale_factor + # Use prompt dict dimensions (from ar2diffusion) as priority, then sampling_params + height = ( + ar_height or req.sampling_params.height or img_height or self.default_sample_size * self.vae_scale_factor + ) + width = ar_width or req.sampling_params.width or img_width or self.default_sample_size * self.vae_scale_factor num_inference_steps = req.sampling_params.num_inference_steps or 50 guidance_scale = req.sampling_params.guidance_scale or 1.5 + # Ensure dimensions are multiples of vae_scale_factor * patch_size + multiple_of = self.vae_scale_factor * self._patch_size + height = height // multiple_of * multiple_of + width = width // multiple_of * multiple_of + self.check_inputs(prompt=prompt, height=height, width=width, prompt_embeds=prompt_embeds) batch_size = 1 @@ -753,6 +768,20 @@ def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput: prior_token_id = prior_token_id.to(device=self.device, dtype=torch.long) if prior_token_id.dim() == 1: prior_token_id = prior_token_id.unsqueeze(0) + + # Validate that prior_token_id seq_len matches dimensions + prior_seq_len = prior_token_id.shape[1] + expected_seq_len = (height // self.vae_scale_factor // self._patch_size) * ( + width // self.vae_scale_factor // self._patch_size + ) + if prior_seq_len != expected_seq_len: + raise ValueError( + f"prior_token_ids seq_len ({prior_seq_len}) doesn't match dimensions " + f"({height}x{width}, expected seq_len={expected_seq_len}). " + f"This indicates a mismatch between AR output and Diffusion input. " + f"Please ensure ar2diffusion returns correct height/width." + ) + prior_token_image_ids = None if external_prior_image_ids is not None: if isinstance(external_prior_image_ids, torch.Tensor):