Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions vllm_omni/diffusion/cache/teacache/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,15 @@
# Bagel transformer coefficients
# Using Qwen's coefficients as reasonable default given shared architecture
"Bagel": [1.33313129e06, -1.68644226e05, 7.95050740e03, -1.63747873e02, 1.26352397e00],
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test plan says results will be provided "as soon as possible" — are these coefficients validated at all yet? Shipping untuned polynomial coefficients could silently degrade output quality. Would be good to at least confirm basic generation quality before merging.

# OmniGen2 transformer coefficients
# Copied from Qwen-Image, need to be tuned specifically for OmniGen2 in future
"OmniGen2Transformer2DModel": [
-4.50000000e02,
2.80000000e02,
-4.50000000e01,
3.20000000e00,
-2.00000000e-02,
],
# Z-Image transformer coefficients
# Copied from Qwen-Image, need to be tuned specifically for Z-Image in future
"ZImageTransformer2DModel": [
Expand Down
168 changes: 168 additions & 0 deletions vllm_omni/diffusion/cache/teacache/extractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,6 +579,173 @@ def postprocess(h):
)


def extract_omnigen2_context(
module: nn.Module,
hidden_states: torch.Tensor | list[torch.Tensor],
timestep: torch.Tensor,
text_hidden_states: torch.Tensor,
freqs_cis: torch.Tensor,
text_attention_mask: torch.Tensor,
ref_image_hidden_states: list | None = None,
return_dict: bool = False,
**kwargs: Any,
) -> CacheContext:
"""
Extract cache context for OmniGen2Transformer2DModel.

This is the ONLY OmniGen2-specific code needed for TeaCache support.
It encapsulates preprocessing, modulated input extraction, transformer execution,
and postprocessing logic.

Args:
module: OmniGen2Transformer2DModel instance
hidden_states: List of image tensors per batch item, or batched tensor [B, C, H, W]
timestep: Timestep tensor
text_hidden_states: Text encoder hidden states
freqs_cis: Precomputed rotary frequency tensor
text_attention_mask: Attention mask for text tokens
ref_image_hidden_states: Optional reference image hidden states
return_dict: Whether to return Transformer2DModelOutput (passed through to postprocess)

Returns:
CacheContext with all information needed for generic caching
"""
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from einops import rearrange

if not hasattr(module, "layers") or len(module.layers) == 0:
raise ValueError("Module must have main transformer layers")

# ============================================================================
# PREPROCESSING (OmniGen2-specific)
# ============================================================================
Comment thread
zzhuoxin1508 marked this conversation as resolved.
batch_size = len(hidden_states)
is_hidden_states_tensor = isinstance(hidden_states, torch.Tensor)

if is_hidden_states_tensor:
assert hidden_states.ndim == 4
hidden_states = [_hidden_states for _hidden_states in hidden_states]

device = hidden_states[0].device

temb, text_hidden_states = module.time_caption_embed(timestep, text_hidden_states, hidden_states[0].dtype)

# Flatten and pad images to sequence
(
hidden_states,
ref_image_hidden_states,
img_mask,
ref_img_mask,
l_effective_ref_img_len,
l_effective_img_len,
ref_img_sizes,
img_sizes,
) = module.flat_and_pad_to_seq(hidden_states, ref_image_hidden_states)

# Compute rotary embeddings and sequence lengths
(
context_rotary_emb,
ref_img_rotary_emb,
noise_rotary_emb,
rotary_emb,
encoder_seq_lengths,
seq_lengths,
) = module.rope_embedder(
freqs_cis,
text_attention_mask,
l_effective_ref_img_len,
l_effective_img_len,
ref_img_sizes,
img_sizes,
device,
)

# Context refinement (text)
for layer in module.context_refiner:
text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb)

# Image patch embed + noise refiner
combined_img_hidden_states = module.img_patch_embed_and_refine(
hidden_states,
ref_image_hidden_states,
img_mask,
ref_img_mask,
noise_rotary_emb,
ref_img_rotary_emb,
l_effective_ref_img_len,
l_effective_img_len,
temb,
)

# Build joint (text + image) sequence
max_seq_len = max(seq_lengths)
attention_mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool)
joint_hidden_states = hidden_states.new_zeros(batch_size, max_seq_len, module.config.hidden_size)
for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
attention_mask[i, :seq_len] = True
joint_hidden_states[i, :encoder_seq_len] = text_hidden_states[i, :encoder_seq_len]
joint_hidden_states[i, encoder_seq_len:seq_len] = combined_img_hidden_states[i, : seq_len - encoder_seq_len]

# ============================================================================
# EXTRACT MODULATED INPUT (for cache decision)
# ============================================================================
# Use the first main transformer block's LuminaRMSNormZero modulation.
# norm1.forward returns (norm_hidden_states, gate_msa, scale_mlp, gate_mlp);
# the first element is the modulated input that gets passed to attention.
block = module.layers[0]
modulated_input = block.norm1(joint_hidden_states, temb)[0]

# ============================================================================
# DEFINE TRANSFORMER EXECUTION (OmniGen2-specific)
# ============================================================================
def run_transformer_blocks():
"""Execute all OmniGen2 main transformer blocks."""
h = joint_hidden_states
for layer in module.layers:
h = layer(h, attention_mask, rotary_emb, temb)
return (h,)

# ============================================================================
# DEFINE POSTPROCESSING (OmniGen2-specific)
# ============================================================================
def postprocess(h):
"""Apply OmniGen2-specific output postprocessing."""
h = module.norm_out(h, temb)

p = module.config.patch_size
output = []
for i, (img_size, img_len, seq_len) in enumerate(zip(img_sizes, l_effective_img_len, seq_lengths)):
height, width = img_size
output.append(
rearrange(
h[i][seq_len - img_len : seq_len],
"(h w) (p1 p2 c) -> c (h p1) (w p2)",
h=height // p,
w=width // p,
p1=p,
p2=p,
)
)
if is_hidden_states_tensor:
output = torch.stack(output, dim=0)

if not return_dict:
return output
return Transformer2DModelOutput(sample=output)

# ============================================================================
# RETURN CONTEXT
# ============================================================================
return CacheContext(
modulated_input=modulated_input,
hidden_states=joint_hidden_states,
encoder_hidden_states=None, # OmniGen2 uses unified joint sequence, no separate encoder states
temb=temb,
run_transformer_blocks=run_transformer_blocks,
postprocess=postprocess,
)


Comment thread
zzhuoxin1508 marked this conversation as resolved.
def extract_flux2_klein_context(
module: nn.Module,
hidden_states: torch.Tensor,
Expand Down Expand Up @@ -976,6 +1143,7 @@ def postprocess(h):
"QwenImageTransformer2DModel": extract_qwen_context,
"Bagel": extract_bagel_context,
"ZImageTransformer2DModel": extract_zimage_context,
"OmniGen2Transformer2DModel": extract_omnigen2_context,
"Flux2Klein": extract_flux2_klein_context,
"StableAudioDiTModel": extract_stable_audio_context,
"Flux2Transformer2DModel": extract_flux2_context,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from vllm_omni.diffusion.attention.layer import Attention
from vllm_omni.diffusion.cache.base import CachedTransformer
from vllm_omni.platforms import current_omni_platform

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -668,7 +669,7 @@ def forward(
return hidden_states


class OmniGen2Transformer2DModel(nn.Module):
class OmniGen2Transformer2DModel(CachedTransformer):
"""
OmniGen2 Transformer 2D Model.

Expand Down
Loading