
Commit d21b133

Merge pull request #52 from legitnull/teacache_pr
feat: add support for teacache
2 parents dd92bad + 6e936ec

4 files changed (+139, -7 lines)

inference.py

Lines changed: 16 additions & 0 deletions
@@ -153,6 +153,17 @@ def parse_args() -> argparse.Namespace:
         action="store_true",
         help="Enable group offload."
     )
+    parser.add_argument(
+        "--enable_teacache",
+        action="store_true",
+        help="Enable teacache to speed up inference."
+    )
+    parser.add_argument(
+        "--rel_l1_thresh",
+        type=float,
+        default=0.05,
+        help="Relative L1 threshold for teacache."
+    )
     return parser.parse_args()

 def load_pipeline(args: argparse.Namespace, accelerator: Accelerator, weight_dtype: torch.dtype) -> OmniGen2Pipeline:
@@ -161,6 +172,7 @@ def load_pipeline(args: argparse.Namespace, accelerator: Accelerator, weight_dty
         torch_dtype=weight_dtype,
         trust_remote_code=True,
     )
+
     if args.transformer_path:
         print(f"Transformer weights loaded from {args.transformer_path}")
         pipeline.transformer = OmniGen2Transformer2DModel.from_pretrained(
@@ -178,6 +190,10 @@ def load_pipeline(args: argparse.Namespace, accelerator: Accelerator, weight_dty
         print(f"LoRA weights loaded from {args.transformer_lora_path}")
         pipeline.load_lora_weights(args.transformer_lora_path)

+    if args.enable_teacache:
+        pipeline.transformer.enable_teacache = True
+        pipeline.transformer.rel_l1_thresh = args.rel_l1_thresh
+
     if args.scheduler == "dpmsolver++":
         from omnigen2.schedulers.scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
         scheduler = DPMSolverMultistepScheduler(
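Note: the same two knobs can also be set programmatically on a loaded pipeline, mirroring what load_pipeline does above. A minimal sketch, assuming the pipeline class exported by this repository and a placeholder checkpoint path (neither is part of this diff):

# Sketch: enable TeaCache without going through inference.py's CLI flags.
# "path/to/OmniGen2" and bfloat16 are placeholders, not from this PR.
import torch
from omnigen2.pipelines.omnigen2.pipeline_omnigen2 import OmniGen2Pipeline

pipeline = OmniGen2Pipeline.from_pretrained(
    "path/to/OmniGen2",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)
pipeline.transformer.enable_teacache = True  # same effect as --enable_teacache
pipeline.transformer.rel_l1_thresh = 0.05    # same default as --rel_l1_thresh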

omnigen2/models/transformers/transformer_omnigen2.py

Lines changed: 50 additions & 7 deletions
@@ -2,6 +2,8 @@
 import itertools
 from typing import Any, Dict, List, Optional, Tuple, Union

+import numpy as np
+
 import torch
 import torch.nn as nn

@@ -20,6 +22,7 @@
 from .block_lumina2 import LuminaLayerNormContinuous, LuminaRMSNormZero, LuminaFeedForward, Lumina2CombinedTimestepCaptionEmbedding

 from ...utils.import_utils import is_triton_available, is_flash_attn_available
+from ...utils.teacache_util import TeaCacheParams

 if is_triton_available():
     from ...ops.triton.layer_norm import RMSNorm
@@ -28,7 +31,6 @@

 logger = logging.get_logger(__name__)

-
 class OmniGen2TransformerBlock(nn.Module):
     """
     Transformer block for OmniGen2 model.
@@ -342,6 +344,14 @@ def __init__(

         self.initialize_weights()

+        # TeaCache settings
+        self.enable_teacache = False
+        self.rel_l1_thresh = 0.05
+        self.teacache_params = TeaCacheParams()
+
+        coefficients = [-5.48259225, 11.48772289, -4.47407401, 2.47730926, -0.03316487]
+        self.rescale_func = np.poly1d(coefficients)
+
     def initialize_weights(self) -> None:
         """
         Initialize the weights of the model.
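Note: rescale_func is a degree-4 polynomial applied to the raw relative L1 change of the modulated input; following the TeaCache method, such coefficients are fitted offline and should be treated as specific to this model. A quick sketch of what np.poly1d builds here (the evaluation points are arbitrary examples):

# np.poly1d takes coefficients highest-degree first, so rescale(x) computes
# -5.48259225*x**4 + 11.48772289*x**3 - 4.47407401*x**2 + 2.47730926*x - 0.03316487
import numpy as np

rescale = np.poly1d([-5.48259225, 11.48772289, -4.47407401, 2.47730926, -0.03316487])
print(rescale(0.0))   # constant term, about -0.033
print(rescale(0.05))  # rescaled estimate for a 5% relative change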
@@ -589,13 +599,46 @@ def forward(

         hidden_states = joint_hidden_states

-        for layer_idx, layer in enumerate(self.layers):
-            if torch.is_grad_enabled() and self.gradient_checkpointing:
-                hidden_states = self._gradient_checkpointing_func(
-                    layer, hidden_states, attention_mask, rotary_emb, temb
+        if self.enable_teacache:
+            teacache_hidden_states = hidden_states.clone()
+            teacache_temb = temb.clone()
+            modulated_inp, _, _, _ = self.layers[0].norm1(teacache_hidden_states, teacache_temb)
+            if self.teacache_params.is_first_or_last_step:
+                should_calc = True
+                self.teacache_params.accumulated_rel_l1_distance = 0
+            else:
+                self.teacache_params.accumulated_rel_l1_distance += self.rescale_func(
+                    ((modulated_inp - self.teacache_params.previous_modulated_inp).abs().mean() \
+                        / self.teacache_params.previous_modulated_inp.abs().mean()).cpu().item()
                 )
+                if self.teacache_params.accumulated_rel_l1_distance < self.rel_l1_thresh:
+                    should_calc = False
+                else:
+                    should_calc = True
+                    self.teacache_params.accumulated_rel_l1_distance = 0
+            self.teacache_params.previous_modulated_inp = modulated_inp
+
+        if self.enable_teacache:
+            if not should_calc:
+                hidden_states += self.teacache_params.previous_residual
             else:
-                hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb)
+                ori_hidden_states = hidden_states.clone()
+                for layer_idx, layer in enumerate(self.layers):
+                    if torch.is_grad_enabled() and self.gradient_checkpointing:
+                        hidden_states = self._gradient_checkpointing_func(
+                            layer, hidden_states, attention_mask, rotary_emb, temb
+                        )
+                    else:
+                        hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb)
+                self.teacache_params.previous_residual = hidden_states - ori_hidden_states
+        else:
+            for layer_idx, layer in enumerate(self.layers):
+                if torch.is_grad_enabled() and self.gradient_checkpointing:
+                    hidden_states = self._gradient_checkpointing_func(
+                        layer, hidden_states, attention_mask, rotary_emb, temb
+                    )
+                else:
+                    hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb)

         # 4. Output norm & projection
         hidden_states = self.norm_out(hidden_states, temb)
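Note: extracted from the forward pass above, the caching policy is: accumulate the rescaled relative L1 change of the first block's modulated input across steps; run the full layer stack only when the accumulator crosses rel_l1_thresh (and always on the first and last steps), otherwise replay the cached residual. A self-contained sketch of that decision, using a generic params object with the same fields as TeaCacheParams (not the repository's exact code path):

# Returns True when the full transformer stack must be recomputed;
# `modulated_inp` is the first block's norm1 output for the current step.
import numpy as np

rescale_func = np.poly1d([-5.48259225, 11.48772289, -4.47407401, 2.47730926, -0.03316487])

def should_run_layers(params, modulated_inp, rel_l1_thresh=0.05):
    if params.is_first_or_last_step:
        should_calc = True
        params.accumulated_rel_l1_distance = 0
    else:
        rel_change = ((modulated_inp - params.previous_modulated_inp).abs().mean()
                      / params.previous_modulated_inp.abs().mean()).cpu().item()
        params.accumulated_rel_l1_distance += rescale_func(rel_change)
        should_calc = params.accumulated_rel_l1_distance >= rel_l1_thresh
        if should_calc:
            params.accumulated_rel_l1_distance = 0
    params.previous_modulated_inp = modulated_inp
    return should_calc

When this returns False, the forward pass adds teacache_params.previous_residual to hidden_states instead of running the layers.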
@@ -614,4 +657,4 @@ def forward(

         if not return_dict:
             return output
-        return Transformer2DModelOutput(sample=output)
\ No newline at end of file
+        return Transformer2DModelOutput(sample=output)

omnigen2/pipelines/omnigen2/pipeline_omnigen2.py

Lines changed: 30 additions & 0 deletions
@@ -46,8 +46,12 @@
 from diffusers.utils import BaseOutput

 from omnigen2.pipelines.image_processor import OmniGen2ImageProcessor
+
+from omnigen2.utils.teacache_util import TeaCacheParams
+
 from ..lora_pipeline import OmniGen2LoraLoaderMixin

+
 if is_torch_xla_available():
     import torch_xla.core.xla_model as xm

@@ -632,8 +636,19 @@ def processing(
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
         self._num_timesteps = len(timesteps)

+        # Use different TeaCacheParams for different conditions
+        if self.transformer.enable_teacache:
+            teacache_params = TeaCacheParams()
+            teacache_params_uncond = TeaCacheParams()
+            teacache_params_ref = TeaCacheParams()
+
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
+
+                if self.transformer.enable_teacache:
+                    teacache_params.is_first_or_last_step = i == 0 or i == len(timesteps) - 1
+                    self.transformer.teacache_params = teacache_params
+
                 model_pred = self.predict(
                     t=t,
                     latents=latents,
@@ -646,6 +661,11 @@
                 image_guidance_scale = self.image_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0

                 if text_guidance_scale > 1.0 and image_guidance_scale > 1.0:
+
+                    if self.transformer.enable_teacache:
+                        teacache_params_ref.is_first_or_last_step = i == 0 or i == len(timesteps) - 1
+                        self.transformer.teacache_params = teacache_params_ref
+
                     model_pred_ref = self.predict(
                         t=t,
                         latents=latents,
@@ -656,6 +676,11 @@
                     )

                     if image_guidance_scale != 1:
+
+                        if self.transformer.enable_teacache:
+                            teacache_params_uncond.is_first_or_last_step = i == 0 or i == len(timesteps) - 1
+                            self.transformer.teacache_params = teacache_params_uncond
+
                         model_pred_uncond = self.predict(
                             t=t,
                             latents=latents,
@@ -670,6 +695,11 @@
                         model_pred = model_pred_uncond + image_guidance_scale * (model_pred_ref - model_pred_uncond) + \
                             text_guidance_scale * (model_pred - model_pred_ref)
                 elif text_guidance_scale > 1.0:
+
+                    if self.transformer.enable_teacache:
+                        teacache_params_uncond.is_first_or_last_step = i == 0 or i == len(timesteps) - 1
+                        self.transformer.teacache_params = teacache_params_uncond
+
                     model_pred_uncond = self.predict(
                         t=t,
                         latents=latents,
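Note: three separate TeaCacheParams instances are needed because the conditional, reference, and unconditional predictions each form their own sequence over timesteps; sharing one cache would compare modulated inputs (and reuse residuals) across unrelated conditions. A toy illustration of the swap-in pattern (names and step count are illustrative, not pipeline code):

# One cache per guidance branch, swapped into the transformer before each predict().
from omnigen2.utils.teacache_util import TeaCacheParams

num_steps = 28  # placeholder
caches = {"cond": TeaCacheParams(), "ref": TeaCacheParams(), "uncond": TeaCacheParams()}

for i in range(num_steps):
    for name, params in caches.items():
        params.is_first_or_last_step = i == 0 or i == num_steps - 1
        # transformer.teacache_params = params  # then run that branch's predict()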

omnigen2/utils/teacache_util.py

Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
+"""
+Utility for TeaCache
+
+Copyright 2025 BAAI, The OmniGen2 Team and The HuggingFace Team. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+
+@dataclass
+class TeaCacheParams:
+    """
+    TeaCache parameters for `OmniGen2Transformer2DModel`
+    See https://github.com/ali-vilab/TeaCache/ for a more comprehensive understanding
+
+    Args:
+        previous_residual (Optional[torch.Tensor]):
+            The tensor difference between the output and the input of the transformer layers from the previous timestep.
+        previous_modulated_inp (Optional[torch.Tensor]):
+            The modulated input from the previous timestep used to indicate the change of the transformer layer's output.
+        accumulated_rel_l1_distance (float):
+            The accumulated relative L1 distance.
+        is_first_or_last_step (bool):
+            Whether the current timestep is the first or last step.
+    """
+    previous_residual: Optional[torch.Tensor] = None
+    previous_modulated_inp: Optional[torch.Tensor] = None
+    accumulated_rel_l1_distance: float = 0
+    is_first_or_last_step: bool = False
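Note: the dataclass is plain mutable state, so a fresh instance per denoising pass starts with an empty cache. A minimal sketch of the defaults:

# Defaults mean "nothing cached yet"; the transformer fills these in as steps run.
from omnigen2.utils.teacache_util import TeaCacheParams

params = TeaCacheParams()
assert params.previous_residual is None
assert params.previous_modulated_inp is None
assert params.accumulated_rel_l1_distance == 0
params.is_first_or_last_step = True  # set by the pipeline at each timestep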
