16 changes: 16 additions & 0 deletions inference.py
@@ -153,6 +153,17 @@ def parse_args() -> argparse.Namespace:
        action="store_true",
        help="Enable group offload."
    )
    parser.add_argument(
        "--enable_teacache",
        action="store_true",
        help="Enable TeaCache to speed up inference."
    )
    parser.add_argument(
        "--rel_l1_thresh",
        type=float,
        default=0.05,
        help="Relative L1 threshold for TeaCache."
    )
    return parser.parse_args()

def load_pipeline(args: argparse.Namespace, accelerator: Accelerator, weight_dtype: torch.dtype) -> OmniGen2Pipeline:
@@ -161,6 +172,7 @@ def load_pipeline(args: argparse.Namespace, accelerator: Accelerator, weight_dtype: torch.dtype) -> OmniGen2Pipeline:
        torch_dtype=weight_dtype,
        trust_remote_code=True,
    )

    if args.transformer_path:
        print(f"Transformer weights loaded from {args.transformer_path}")
        pipeline.transformer = OmniGen2Transformer2DModel.from_pretrained(
@@ -178,6 +190,10 @@ def load_pipeline(args: argparse.Namespace, accelerator: Accelerator, weight_dtype: torch.dtype) -> OmniGen2Pipeline:
        print(f"LoRA weights loaded from {args.transformer_lora_path}")
        pipeline.load_lora_weights(args.transformer_lora_path)

    if args.enable_teacache:
        pipeline.transformer.enable_teacache = True
        pipeline.transformer.rel_l1_thresh = args.rel_l1_thresh

    if args.scheduler == "dpmsolver++":
        from omnigen2.schedulers.scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
        scheduler = DPMSolverMultistepScheduler(
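For reviewers who want to try this without the CLI, here is a minimal sketch of what the two new flags do; the checkpoint id and dtype below are illustrative assumptions, not part of this PR:

import torch
from omnigen2.pipelines.omnigen2.pipeline_omnigen2 import OmniGen2Pipeline

# Hypothetical checkpoint id; substitute whatever you normally pass to inference.py.
pipeline = OmniGen2Pipeline.from_pretrained(
    "OmniGen2/OmniGen2", torch_dtype=torch.bfloat16, trust_remote_code=True
)

# Equivalent of --enable_teacache --rel_l1_thresh 0.05:
pipeline.transformer.enable_teacache = True
pipeline.transformer.rel_l1_thresh = 0.05  # larger values skip more steps at some quality cost

Raising rel_l1_thresh trades fidelity for speed; 0.05 is the default this PR ships.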
57 changes: 50 additions & 7 deletions omnigen2/models/transformers/transformer_omnigen2.py
@@ -2,6 +2,8 @@
import itertools
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np

import torch
import torch.nn as nn

@@ -20,6 +22,7 @@
from .block_lumina2 import LuminaLayerNormContinuous, LuminaRMSNormZero, LuminaFeedForward, Lumina2CombinedTimestepCaptionEmbedding

from ...utils.import_utils import is_triton_available, is_flash_attn_available
from ...utils.teacache_util import TeaCacheParams

if is_triton_available():
    from ...ops.triton.layer_norm import RMSNorm
@@ -28,7 +31,6 @@

logger = logging.get_logger(__name__)

class OmniGen2TransformerBlock(nn.Module):
    """
    Transformer block for OmniGen2 model.
@@ -342,6 +344,14 @@ def __init__(

        self.initialize_weights()

        # TeaCache settings
        self.enable_teacache = False
        self.rel_l1_thresh = 0.05
        self.teacache_params = TeaCacheParams()

        # Empirical polynomial that rescales the relative L1 change of the modulated
        # input into an estimate of the relative change of the transformer output.
        coefficients = [-5.48259225, 11.48772289, -4.47407401, 2.47730926, -0.03316487]
        self.rescale_func = np.poly1d(coefficients)

    def initialize_weights(self) -> None:
        """
        Initialize the weights of the model.
@@ -589,13 +599,46 @@ def forward(

        hidden_states = joint_hidden_states

        if self.enable_teacache:
            # Cheap proxy for how much this step's output will differ from the last:
            # the modulated input of the first transformer block.
            teacache_hidden_states = hidden_states.clone()
            teacache_temb = temb.clone()
            modulated_inp, _, _, _ = self.layers[0].norm1(teacache_hidden_states, teacache_temb)
            if self.teacache_params.is_first_or_last_step:
                should_calc = True
                self.teacache_params.accumulated_rel_l1_distance = 0
            else:
                self.teacache_params.accumulated_rel_l1_distance += self.rescale_func(
                    ((modulated_inp - self.teacache_params.previous_modulated_inp).abs().mean()
                     / self.teacache_params.previous_modulated_inp.abs().mean()).cpu().item()
                )
                if self.teacache_params.accumulated_rel_l1_distance < self.rel_l1_thresh:
                    should_calc = False
                else:
                    should_calc = True
                    self.teacache_params.accumulated_rel_l1_distance = 0
            self.teacache_params.previous_modulated_inp = modulated_inp

        if self.enable_teacache:
            if not should_calc:
                # Skip the transformer layers entirely and replay the cached residual.
                hidden_states += self.teacache_params.previous_residual
            else:
                ori_hidden_states = hidden_states.clone()
                for layer_idx, layer in enumerate(self.layers):
                    if torch.is_grad_enabled() and self.gradient_checkpointing:
                        hidden_states = self._gradient_checkpointing_func(
                            layer, hidden_states, attention_mask, rotary_emb, temb
                        )
                    else:
                        hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb)
                # Cache the residual so skipped steps can reuse it.
                self.teacache_params.previous_residual = hidden_states - ori_hidden_states
        else:
            for layer_idx, layer in enumerate(self.layers):
                if torch.is_grad_enabled() and self.gradient_checkpointing:
                    hidden_states = self._gradient_checkpointing_func(
                        layer, hidden_states, attention_mask, rotary_emb, temb
                    )
                else:
                    hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb)
        # 4. Output norm & projection
        hidden_states = self.norm_out(hidden_states, temb)
@@ -614,4 +657,4 @@ def forward(

        if not return_dict:
            return output
        return Transformer2DModelOutput(sample=output)
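To review the caching rule in isolation, here is a self-contained toy sketch of the same decision logic; run_layers, the tensor shapes, and the step count are illustrative stand-ins, not part of this PR:

import numpy as np
import torch

rescale = np.poly1d([-5.48259225, 11.48772289, -4.47407401, 2.47730926, -0.03316487])
threshold = 0.05  # mirrors rel_l1_thresh
accumulated = 0.0
prev_inp = None
prev_residual = None

def run_layers(x):
    # Stand-in for the full transformer stack.
    return x * 1.01

num_steps = 10
for step in range(num_steps):
    x = torch.randn(2, 16)  # stand-in for the modulated input at this step
    if step in (0, num_steps - 1) or prev_inp is None:
        should_calc, accumulated = True, 0.0
    else:
        rel_l1 = ((x - prev_inp).abs().mean() / prev_inp.abs().mean()).item()
        accumulated += rescale(rel_l1)
        if accumulated < threshold:
            should_calc = False  # output barely moved: replay the cached residual
        else:
            should_calc, accumulated = True, 0.0
    prev_inp = x
    if should_calc:
        out = run_layers(x)
        prev_residual = out - x
    else:
        out = x + prev_residual

The polynomial converts the easily measured input change into an estimate of the output change, which is why the raw L1 ratio is not compared against the threshold directly.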
30 changes: 30 additions & 0 deletions omnigen2/pipelines/omnigen2/pipeline_omnigen2.py
@@ -46,8 +46,12 @@
from diffusers.utils import BaseOutput

from omnigen2.pipelines.image_processor import OmniGen2ImageProcessor

from omnigen2.utils.teacache_util import TeaCacheParams

from ..lora_pipeline import OmniGen2LoraLoaderMixin


if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

@@ -632,8 +636,19 @@ def processing(
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        self._num_timesteps = len(timesteps)

        # Keep a separate TeaCacheParams per guidance branch (cond / ref / uncond),
        # since each branch accumulates its own modulated-input and residual history.
        if self.transformer.enable_teacache:
            teacache_params = TeaCacheParams()
            teacache_params_uncond = TeaCacheParams()
            teacache_params_ref = TeaCacheParams()

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.transformer.enable_teacache:
                    teacache_params.is_first_or_last_step = i == 0 or i == len(timesteps) - 1
                    self.transformer.teacache_params = teacache_params

                model_pred = self.predict(
                    t=t,
                    latents=latents,
@@ -646,6 +661,11 @@
                image_guidance_scale = self.image_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0

                if text_guidance_scale > 1.0 and image_guidance_scale > 1.0:
                    if self.transformer.enable_teacache:
                        teacache_params_ref.is_first_or_last_step = i == 0 or i == len(timesteps) - 1
                        self.transformer.teacache_params = teacache_params_ref

                    model_pred_ref = self.predict(
                        t=t,
                        latents=latents,
@@ -656,6 +676,11 @@
                    )

                    if image_guidance_scale != 1:
                        if self.transformer.enable_teacache:
                            teacache_params_uncond.is_first_or_last_step = i == 0 or i == len(timesteps) - 1
                            self.transformer.teacache_params = teacache_params_uncond

                        model_pred_uncond = self.predict(
                            t=t,
                            latents=latents,
@@ -670,6 +695,11 @@
                    model_pred = model_pred_uncond + image_guidance_scale * (model_pred_ref - model_pred_uncond) + \
                        text_guidance_scale * (model_pred - model_pred_ref)
                elif text_guidance_scale > 1.0:
                    if self.transformer.enable_teacache:
                        teacache_params_uncond.is_first_or_last_step = i == 0 or i == len(timesteps) - 1
                        self.transformer.teacache_params = teacache_params_uncond

                    model_pred_uncond = self.predict(
                        t=t,
                        latents=latents,
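A subtlety worth noting in review: all three guidance branches run through the same transformer, so the pipeline swaps the matching TeaCacheParams in before each predict call; otherwise the branches would overwrite each other's cached residuals. A stubbed sketch of the pattern, assuming the omnigen2 package is importable; the Transformer stand-in and branch names are illustrative:

from omnigen2.utils.teacache_util import TeaCacheParams

class Transformer:  # stand-in for OmniGen2Transformer2DModel
    teacache_params = TeaCacheParams()

transformer = Transformer()
timesteps = list(range(5))
branch_state = {name: TeaCacheParams() for name in ("cond", "ref", "uncond")}

for i, t in enumerate(timesteps):
    for name in ("cond", "ref", "uncond"):
        params = branch_state[name]
        params.is_first_or_last_step = i == 0 or i == len(timesteps) - 1
        transformer.teacache_params = params  # each branch keeps its own residual history
        # ... self.predict(...) would run the transformer here ...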
43 changes: 43 additions & 0 deletions omnigen2/utils/teacache_util.py
@@ -0,0 +1,43 @@
"""
Utility for TeaCache

Copyright 2025 BAAI, The OmniGen2 Team and The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from dataclasses import dataclass
from typing import Optional

import torch

@dataclass
class TeaCacheParams:
    """
    TeaCache parameters for `OmniGen2Transformer2DModel`.
    See https://github.com/ali-vilab/TeaCache/ for a more comprehensive understanding.

    Args:
        previous_residual (Optional[torch.Tensor]):
            The difference between the output and the input of the transformer layers
            at the previous timestep.
        previous_modulated_inp (Optional[torch.Tensor]):
            The modulated input from the previous timestep, used to estimate how much
            the transformer layers' output changes between timesteps.
        accumulated_rel_l1_distance (float):
            The relative L1 distance accumulated since the last full forward pass.
        is_first_or_last_step (bool):
            Whether the current timestep is the first or last step (these are always
            computed in full).
    """
    previous_residual: Optional[torch.Tensor] = None
    previous_modulated_inp: Optional[torch.Tensor] = None
    accumulated_rel_l1_distance: float = 0.0
    is_first_or_last_step: bool = False
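For completeness, the fields above split into two roles: previous_modulated_inp and accumulated_rel_l1_distance drive the skip decision, while previous_residual is what a skipped step replays. A minimal illustration, assuming the package is importable; the tensor values are arbitrary:

import torch
from omnigen2.utils.teacache_util import TeaCacheParams

params = TeaCacheParams()
params.is_first_or_last_step = True              # first step: always compute
params.previous_modulated_inp = torch.randn(1, 4)
params.previous_residual = torch.randn(1, 4)     # cached after a full forward pass
params.accumulated_rel_l1_distance = 0.0         # reset whenever the layers are recomputed
print(params)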