diff --git a/vllm_omni/diffusion/cache/teacache/config.py b/vllm_omni/diffusion/cache/teacache/config.py
index ecf3bfc1d3d..92b9577849f 100644
--- a/vllm_omni/diffusion/cache/teacache/config.py
+++ b/vllm_omni/diffusion/cache/teacache/config.py
@@ -37,6 +37,15 @@
     # Bagel transformer coefficients
     # Using Qwen's coefficients as reasonable default given shared architecture
     "Bagel": [1.33313129e06, -1.68644226e05, 7.95050740e03, -1.63747873e02, 1.26352397e00],
+    # OmniGen2 transformer coefficients
+    # Copied from Qwen-Image, need to be tuned specifically for OmniGen2 in future
+    "OmniGen2Transformer2DModel": [
+        -4.50000000e02,
+        2.80000000e02,
+        -4.50000000e01,
+        3.20000000e00,
+        -2.00000000e-02,
+    ],
     # Z-Image transformer coefficients
     # Copied from Qwen-Image, need to be tuned specifically for Z-Image in future
     "ZImageTransformer2DModel": [
diff --git a/vllm_omni/diffusion/cache/teacache/extractors.py b/vllm_omni/diffusion/cache/teacache/extractors.py
index 3d247e31878..9c88371398b 100644
--- a/vllm_omni/diffusion/cache/teacache/extractors.py
+++ b/vllm_omni/diffusion/cache/teacache/extractors.py
@@ -579,6 +579,173 @@ def postprocess(h):
     )
 
 
+def extract_omnigen2_context(
+    module: nn.Module,
+    hidden_states: torch.Tensor | list[torch.Tensor],
+    timestep: torch.Tensor,
+    text_hidden_states: torch.Tensor,
+    freqs_cis: torch.Tensor,
+    text_attention_mask: torch.Tensor,
+    ref_image_hidden_states: list | None = None,
+    return_dict: bool = False,
+    **kwargs: Any,
+) -> CacheContext:
+    """
+    Extract cache context for OmniGen2Transformer2DModel.
+
+    This is the OmniGen2-specific extractor used for TeaCache support.
+    It encapsulates preprocessing, modulated input extraction, transformer execution,
+    and postprocessing logic. Extra keyword arguments (**kwargs) are accepted but unused.
+
+    Args:
+        module: OmniGen2Transformer2DModel instance; ValueError is raised if it has no non-empty `layers`
+        hidden_states: List of image tensors per batch item, or batched tensor [B, C, H, W]
+        timestep: Timestep tensor
+        text_hidden_states: Text encoder hidden states
+        freqs_cis: Precomputed rotary frequency tensor
+        text_attention_mask: Attention mask for text tokens
+        ref_image_hidden_states: Optional reference image hidden states, forwarded to flat_and_pad_to_seq
+        return_dict: If True, postprocess wraps its output in Transformer2DModelOutput; else returns it bare
+
+    Returns:
+        CacheContext for generic caching (encoder_hidden_states is None; see the return statement)
+    """
+    from diffusers.models.modeling_outputs import Transformer2DModelOutput
+    from einops import rearrange
+
+    if not hasattr(module, "layers") or len(module.layers) == 0:
+        raise ValueError("Module must have main transformer layers")
+
+    # ============================================================================
+    # PREPROCESSING (OmniGen2-specific)
+    # ============================================================================
+    batch_size = len(hidden_states)  # list length, or dim 0 of a batched tensor
+    is_hidden_states_tensor = isinstance(hidden_states, torch.Tensor)  # remembered so postprocess can re-stack
+
+    if is_hidden_states_tensor:
+        assert hidden_states.ndim == 4
+        hidden_states = [_hidden_states for _hidden_states in hidden_states]  # split into per-sample tensors
+
+    device = hidden_states[0].device
+
+    temb, text_hidden_states = module.time_caption_embed(timestep, text_hidden_states, hidden_states[0].dtype)
+
+    # Flatten and pad images to sequence (rebinds hidden_states to the helper's output, used below)
+    (
+        hidden_states,
+        ref_image_hidden_states,
+        img_mask,
+        ref_img_mask,
+        l_effective_ref_img_len,
+        l_effective_img_len,
+        ref_img_sizes,
+        img_sizes,
+    ) = module.flat_and_pad_to_seq(hidden_states, ref_image_hidden_states)
+
+    # Compute rotary embeddings and sequence lengths
+    (
+        context_rotary_emb,
+        ref_img_rotary_emb,
+        noise_rotary_emb,
+        rotary_emb,
+        encoder_seq_lengths,
+        seq_lengths,
+    ) = module.rope_embedder(
+        freqs_cis,
+        text_attention_mask,
+        l_effective_ref_img_len,
+        l_effective_img_len,
+        ref_img_sizes,
+        img_sizes,
+        device,
+    )
+
+    # Context refinement (text): run every context_refiner layer over the text states
+    for layer in module.context_refiner:
+        text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb)
+
+    # Image patch embed + noise refiner
+    combined_img_hidden_states = module.img_patch_embed_and_refine(
+        hidden_states,
+        ref_image_hidden_states,
+        img_mask,
+        ref_img_mask,
+        noise_rotary_emb,
+        ref_img_rotary_emb,
+        l_effective_ref_img_len,
+        l_effective_img_len,
+        temb,
+    )
+
+    # Build joint sequence per sample: text tokens first, then image tokens; the tail stays zero and masked out
+    max_seq_len = max(seq_lengths)
+    attention_mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool)
+    joint_hidden_states = hidden_states.new_zeros(batch_size, max_seq_len, module.config.hidden_size)
+    for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
+        attention_mask[i, :seq_len] = True
+        joint_hidden_states[i, :encoder_seq_len] = text_hidden_states[i, :encoder_seq_len]
+        joint_hidden_states[i, encoder_seq_len:seq_len] = combined_img_hidden_states[i, : seq_len - encoder_seq_len]
+
+    # ============================================================================
+    # EXTRACT MODULATED INPUT (for cache decision)
+    # ============================================================================
+    # Use the first main transformer block's LuminaRMSNormZero modulation.
+    # norm1.forward returns (norm_hidden_states, gate_msa, scale_mlp, gate_mlp);
+    # the first element is the modulated input that gets passed to attention.
+    block = module.layers[0]
+    modulated_input = block.norm1(joint_hidden_states, temb)[0]
+
+    # ============================================================================
+    # DEFINE TRANSFORMER EXECUTION (OmniGen2-specific)
+    # ============================================================================
+    def run_transformer_blocks():
+        """Execute all OmniGen2 main transformer blocks."""
+        h = joint_hidden_states
+        for layer in module.layers:
+            h = layer(h, attention_mask, rotary_emb, temb)
+        return (h,)  # 1-tuple; presumably the shape CacheContext consumers expect - TODO confirm
+
+    # ============================================================================
+    # DEFINE POSTPROCESSING (OmniGen2-specific)
+    # ============================================================================
+    def postprocess(h):
+        """Apply OmniGen2-specific output postprocessing."""
+        h = module.norm_out(h, temb)
+
+        p = module.config.patch_size
+        output = []
+        for i, (img_size, img_len, seq_len) in enumerate(zip(img_sizes, l_effective_img_len, seq_lengths)):
+            height, width = img_size
+            output.append(
+                rearrange(
+                    h[i][seq_len - img_len : seq_len],  # last img_len tokens of sample i's unmasked span
+                    "(h w) (p1 p2 c) -> c (h p1) (w p2)",
+                    h=height // p,
+                    w=width // p,
+                    p1=p,
+                    p2=p,
+                )
+            )
+        if is_hidden_states_tensor:
+            output = torch.stack(output, dim=0)  # tensor in -> tensor out; list in -> list out
+
+        if not return_dict:
+            return output
+        return Transformer2DModelOutput(sample=output)
+
+    # ============================================================================
+    # RETURN CONTEXT
+    # ============================================================================
+    return CacheContext(
+        modulated_input=modulated_input,
+        hidden_states=joint_hidden_states,
+        encoder_hidden_states=None,  # OmniGen2 uses unified joint sequence, no separate encoder states
+        temb=temb,
+        run_transformer_blocks=run_transformer_blocks,
+        postprocess=postprocess,
+    )
+
+
 def extract_flux2_klein_context(
     module: nn.Module,
     hidden_states: torch.Tensor,
@@ -976,6 +1143,7 @@ def postprocess(h):
     "QwenImageTransformer2DModel": extract_qwen_context,
     "Bagel": extract_bagel_context,
     "ZImageTransformer2DModel": extract_zimage_context,
+    "OmniGen2Transformer2DModel": extract_omnigen2_context,
     "Flux2Klein": extract_flux2_klein_context,
     "StableAudioDiTModel": extract_stable_audio_context,
     "Flux2Transformer2DModel": extract_flux2_context,
diff --git a/vllm_omni/diffusion/models/omnigen2/omnigen2_transformer.py b/vllm_omni/diffusion/models/omnigen2/omnigen2_transformer.py
index 9ff681a3c0b..653e7406410 100644
--- a/vllm_omni/diffusion/models/omnigen2/omnigen2_transformer.py
+++ b/vllm_omni/diffusion/models/omnigen2/omnigen2_transformer.py
@@ -19,6 +19,7 @@
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from vllm_omni.diffusion.attention.layer import Attention
+from vllm_omni.diffusion.cache.base import CachedTransformer
 from vllm_omni.platforms import current_omni_platform
 
 logger = logging.getLogger(__name__)
@@ -668,7 +669,7 @@ def forward(
         return hidden_states
 
 
-class OmniGen2Transformer2DModel(nn.Module):
+class OmniGen2Transformer2DModel(CachedTransformer):
     """
     OmniGen2 Transformer 2D Model.
 