diff --git a/inference.py b/inference.py
index 12609c8..442ec7f 100644
--- a/inference.py
+++ b/inference.py
@@ -153,6 +153,17 @@ def parse_args() -> argparse.Namespace:
         action="store_true",
         help="Enable group offload."
     )
+    parser.add_argument(
+        "--enable_teacache",
+        action="store_true",
+        help="Enable teacache to speed up inference."
+    )
+    parser.add_argument(
+        "--rel_l1_thresh",
+        type=float,
+        default=0.05,
+        help="Relative L1 threshold for teacache."
+    )
     return parser.parse_args()
 
 def load_pipeline(args: argparse.Namespace, accelerator: Accelerator, weight_dtype: torch.dtype) -> OmniGen2Pipeline:
@@ -161,6 +172,7 @@ def load_pipeline(args: argparse.Namespace, accelerator: Accelerator, weight_dty
         torch_dtype=weight_dtype,
         trust_remote_code=True,
     )
+
     if args.transformer_path:
         print(f"Transformer weights loaded from {args.transformer_path}")
         pipeline.transformer = OmniGen2Transformer2DModel.from_pretrained(
@@ -178,6 +190,10 @@ def load_pipeline(args: argparse.Namespace, accelerator: Accelerator, weight_dty
         print(f"LoRA weights loaded from {args.transformer_lora_path}")
         pipeline.load_lora_weights(args.transformer_lora_path)
 
+    if args.enable_teacache:
+        pipeline.transformer.enable_teacache = True
+        pipeline.transformer.rel_l1_thresh = args.rel_l1_thresh
+
     if args.scheduler == "dpmsolver++":
         from omnigen2.schedulers.scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
         scheduler = DPMSolverMultistepScheduler(
diff --git a/omnigen2/models/transformers/transformer_omnigen2.py b/omnigen2/models/transformers/transformer_omnigen2.py
index f826dc3..c80c2b8 100644
--- a/omnigen2/models/transformers/transformer_omnigen2.py
+++ b/omnigen2/models/transformers/transformer_omnigen2.py
@@ -2,6 +2,8 @@ import itertools
 
 from typing import Any, Dict, List, Optional, Tuple, Union
 
+import numpy as np
+
 import torch
 import torch.nn as nn
 
@@ -20,6 +22,7 @@
 from .block_lumina2 import LuminaLayerNormContinuous, LuminaRMSNormZero, LuminaFeedForward, Lumina2CombinedTimestepCaptionEmbedding
 
 from ...utils.import_utils import is_triton_available, is_flash_attn_available
+from ...utils.teacache_util import TeaCacheParams
 
 if is_triton_available():
     from ...ops.triton.layer_norm import RMSNorm
@@ -28,7 +31,6 @@
 
 logger = logging.get_logger(__name__)
 
-
 class OmniGen2TransformerBlock(nn.Module):
     """
     Transformer block for OmniGen2 model.
@@ -342,6 +344,14 @@ def __init__(
 
         self.initialize_weights()
 
+        # TeaCache settings
+        self.enable_teacache = False
+        self.rel_l1_thresh = 0.05
+        self.teacache_params = TeaCacheParams()
+
+        coefficients = [-5.48259225, 11.48772289, -4.47407401, 2.47730926, -0.03316487]
+        self.rescale_func = np.poly1d(coefficients)
+
     def initialize_weights(self) -> None:
         """
         Initialize the weights of the model.
@@ -589,13 +599,46 @@ def forward(
 
         hidden_states = joint_hidden_states
 
-        for layer_idx, layer in enumerate(self.layers):
-            if torch.is_grad_enabled() and self.gradient_checkpointing:
-                hidden_states = self._gradient_checkpointing_func(
-                    layer, hidden_states, attention_mask, rotary_emb, temb
+        if self.enable_teacache:
+            teacache_hidden_states = hidden_states.clone()
+            teacache_temb = temb.clone()
+            modulated_inp, _, _, _ = self.layers[0].norm1(teacache_hidden_states, teacache_temb)
+            if self.teacache_params.is_first_or_last_step:
+                should_calc = True
+                self.teacache_params.accumulated_rel_l1_distance = 0
+            else:
+                self.teacache_params.accumulated_rel_l1_distance += self.rescale_func(
+                    ((modulated_inp - self.teacache_params.previous_modulated_inp).abs().mean() \
+                    / self.teacache_params.previous_modulated_inp.abs().mean()).cpu().item()
                 )
+                if self.teacache_params.accumulated_rel_l1_distance < self.rel_l1_thresh:
+                    should_calc = False
+                else:
+                    should_calc = True
+                    self.teacache_params.accumulated_rel_l1_distance = 0
+            self.teacache_params.previous_modulated_inp = modulated_inp
+
+        if self.enable_teacache:
+            if not should_calc:
+                hidden_states += self.teacache_params.previous_residual
             else:
-                hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb)
+                ori_hidden_states = hidden_states.clone()
+                for layer_idx, layer in enumerate(self.layers):
+                    if torch.is_grad_enabled() and self.gradient_checkpointing:
+                        hidden_states = self._gradient_checkpointing_func(
+                            layer, hidden_states, attention_mask, rotary_emb, temb
+                        )
+                    else:
+                        hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb)
+                self.teacache_params.previous_residual = hidden_states - ori_hidden_states
+        else:
+            for layer_idx, layer in enumerate(self.layers):
+                if torch.is_grad_enabled() and self.gradient_checkpointing:
+                    hidden_states = self._gradient_checkpointing_func(
+                        layer, hidden_states, attention_mask, rotary_emb, temb
+                    )
+                else:
+                    hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb)
 
         # 4. Output norm & projection
         hidden_states = self.norm_out(hidden_states, temb)
@@ -614,4 +657,4 @@ def forward(
         if not return_dict:
             return output
 
-        return Transformer2DModelOutput(sample=output)
+        return Transformer2DModelOutput(sample=output)
\ No newline at end of file
diff --git a/omnigen2/pipelines/omnigen2/pipeline_omnigen2.py b/omnigen2/pipelines/omnigen2/pipeline_omnigen2.py
index 91df3d9..a76d744 100644
--- a/omnigen2/pipelines/omnigen2/pipeline_omnigen2.py
+++ b/omnigen2/pipelines/omnigen2/pipeline_omnigen2.py
@@ -46,8 +46,12 @@
 from diffusers.utils import BaseOutput
 
 from omnigen2.pipelines.image_processor import OmniGen2ImageProcessor
+
+from omnigen2.utils.teacache_util import TeaCacheParams
+
 from ..lora_pipeline import OmniGen2LoraLoaderMixin
 
+
 if is_torch_xla_available():
     import torch_xla.core.xla_model as xm
 
@@ -632,8 +636,19 @@ def processing(
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
         self._num_timesteps = len(timesteps)
 
+        # Use different TeaCacheParams for different conditions
+        if self.transformer.enable_teacache:
+            teacache_params = TeaCacheParams()
+            teacache_params_uncond = TeaCacheParams()
+            teacache_params_ref = TeaCacheParams()
+
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
+
+                if self.transformer.enable_teacache:
+                    teacache_params.is_first_or_last_step = i == 0 or i == len(timesteps) - 1
+                    self.transformer.teacache_params = teacache_params
+
                 model_pred = self.predict(
                     t=t,
                     latents=latents,
@@ -646,6 +661,11 @@
                 image_guidance_scale = self.image_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0
 
                 if text_guidance_scale > 1.0 and image_guidance_scale > 1.0:
+
+                    if self.transformer.enable_teacache:
+                        teacache_params_ref.is_first_or_last_step = i == 0 or i == len(timesteps) - 1
+                        self.transformer.teacache_params = teacache_params_ref
+
                     model_pred_ref = self.predict(
                         t=t,
                         latents=latents,
@@ -656,6 +676,11 @@
                     )
 
                     if image_guidance_scale != 1:
+
+                        if self.transformer.enable_teacache:
+                            teacache_params_uncond.is_first_or_last_step = i == 0 or i == len(timesteps) - 1
+                            self.transformer.teacache_params = teacache_params_uncond
+
                         model_pred_uncond = self.predict(
                             t=t,
                             latents=latents,
@@ -670,6 +695,11 @@
                     model_pred = model_pred_uncond + image_guidance_scale * (model_pred_ref - model_pred_uncond) + \
                         text_guidance_scale * (model_pred - model_pred_ref)
                 elif text_guidance_scale > 1.0:
+
+                    if self.transformer.enable_teacache:
+                        teacache_params_uncond.is_first_or_last_step = i == 0 or i == len(timesteps) - 1
+                        self.transformer.teacache_params = teacache_params_uncond
+
                     model_pred_uncond = self.predict(
                         t=t,
                         latents=latents,
diff --git a/omnigen2/utils/teacache_util.py b/omnigen2/utils/teacache_util.py
new file mode 100644
index 0000000..b997ff4
--- /dev/null
+++ b/omnigen2/utils/teacache_util.py
@@ -0,0 +1,43 @@
+"""
+Utility for TeaCache
+
+Copyright 2025 BAAI, The OmniGen2 Team and The HuggingFace Team. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+
+@dataclass
+class TeaCacheParams:
+    """
+    TeaCache parameters for `OmniGen2Transformer2DModel`
+    See https://github.com/ali-vilab/TeaCache/ for a more comprehensive understanding
+
+    Args:
+        previous_residual (Optional[torch.Tensor]):
+            The tensor difference between the output and the input of the transformer layers from the previous timestep.
+        previous_modulated_inp (Optional[torch.Tensor]):
+            The modulated input from the previous timestep used to indicate the change of the transformer layer's output.
+        accumulated_rel_l1_distance (float):
+            The accumulated relative L1 distance.
+        is_first_or_last_step (bool):
+            Whether the current timestep is the first or last step.
+    """
+    previous_residual: Optional[torch.Tensor] = None
+    previous_modulated_inp: Optional[torch.Tensor] = None
+    accumulated_rel_l1_distance: float = 0
+    is_first_or_last_step: bool = False
\ No newline at end of file
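
Illustration only (not part of the patch above): the sketch below isolates the TeaCache skip/recompute decision that the forward() hunk adds to OmniGen2Transformer2DModel. The polynomial coefficients and the accumulated relative-L1 threshold test mirror the diff; the Params dataclass is a minimal stand-in for TeaCacheParams, and run_layers, the tensor shapes, and the toy layer function at the bottom are hypothetical placeholders for the real transformer stack.

from dataclasses import dataclass
from typing import Callable, Optional

import numpy as np
import torch

# Same rescaling polynomial and default threshold as in the diff above.
RESCALE_FUNC = np.poly1d([-5.48259225, 11.48772289, -4.47407401, 2.47730926, -0.03316487])
REL_L1_THRESH = 0.05


@dataclass
class Params:  # minimal stand-in for omnigen2.utils.teacache_util.TeaCacheParams
    previous_residual: Optional[torch.Tensor] = None
    previous_modulated_inp: Optional[torch.Tensor] = None
    accumulated_rel_l1_distance: float = 0.0
    is_first_or_last_step: bool = False


def teacache_forward(
    hidden_states: torch.Tensor,
    modulated_inp: torch.Tensor,
    params: Params,
    run_layers: Callable[[torch.Tensor], torch.Tensor],
) -> torch.Tensor:
    """Run the layer stack, or reuse the residual cached from the last full pass."""
    if params.is_first_or_last_step:
        # First and last timesteps are always computed exactly.
        should_calc = True
        params.accumulated_rel_l1_distance = 0.0
    else:
        # Accumulate the rescaled relative L1 change of the modulated input across steps.
        rel_l1 = ((modulated_inp - params.previous_modulated_inp).abs().mean()
                  / params.previous_modulated_inp.abs().mean()).item()
        params.accumulated_rel_l1_distance += RESCALE_FUNC(rel_l1)
        should_calc = params.accumulated_rel_l1_distance >= REL_L1_THRESH
        if should_calc:
            params.accumulated_rel_l1_distance = 0.0
    params.previous_modulated_inp = modulated_inp

    if not should_calc:
        # Cheap path: approximate the layer stack with the cached residual.
        return hidden_states + params.previous_residual
    out = run_layers(hidden_states)
    params.previous_residual = out - hidden_states
    return out


if __name__ == "__main__":
    torch.manual_seed(0)
    params = Params()
    x = torch.randn(1, 16, 64)
    for step in range(4):
        params.is_first_or_last_step = step in (0, 3)
        # The modulated input is a stand-in here; the real model takes it from layers[0].norm1.
        x = teacache_forward(x, x.clone(), params, lambda h: h * 1.01)

With the patch applied, the same mechanism is driven from inference.py via --enable_teacache and tuned with --rel_l1_thresh (default 0.05); a larger threshold lets more steps reuse the cached residual, trading some fidelity for speed.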