[diffusion] Refactor TeaCache #21613

Open
eitanturok wants to merge 16 commits into sgl-project:main from eitanturok:refactor-teacache-2

Conversation


@eitanturok eitanturok commented Mar 28, 2026

Motivation

This PR cleans up the TeaCache implementation.

  1. TeaCache parameters were split between 4 different classes, updated at different times, and glued together by a messy get_context function. Now there are 3 clearly separated classes with well-defined API boundaries.
  2. Adding TeaCache to a new model now requires 5 lines instead of ~100.
  3. Removed the duplicated code for the positive and negative CFG branches.
  4. Implemented TeaCache more generically (composition instead of inheritance) so we can support different types of caching in the future.
  5. Removed if statements and numpy calls so the code is more torch.compile-friendly.

This is a prerequisite for supporting different timestep caching methods like MagCache (#18498 #19957).

Modifications

Here are the changes in more detail:

1. Fragmented State

The Problem: Parameters were scattered across the model class, TeaCacheMixin, TeaCacheParams, and TeaCacheContext. The TeaCacheParams class confusingly had methods attached to it. These classes and parameters were all updated at different times, making it difficult to manage TeaCache state across different requests.

The Solution: Split logic into three distinct components:

  • TeaCacheParams: A pure data class for user settings (thresholds, offsets). It has no methods and is updated at the start of every new request by a user.
  • TeaCacheState: Manages internal runtime data (cached tensors, accumulated L1 distances). The step counter is now moved here instead of being attached to the model. This is updated after every forward pass.
  • TeaCacheStrategy: The actual implementation logic that takes in TeaCacheParams and TeaCacheState to decide when to skip a computation.
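The first two components can be sketched as plain data holders. (A hypothetical sketch: the field names below are illustrative defaults, not the PR's exact attributes.)

```python
from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class TeaCacheParams:
    # User-facing settings only: no methods, set at the start of each request.
    rel_l1_thresh: float = 0.2        # skip threshold (illustrative name)
    coefficients: tuple = (1.0, 0.0)  # polynomial rescaling coefficients
    warmup_steps: int = 1             # always-compute steps at the start

@dataclass
class TeaCacheState:
    # Internal runtime data, updated after every forward pass.
    cnt: int = 0                      # step counter, moved off the model
    accumulated_distance: float = 0.0 # running rescaled L1 distance
    previous_modulated_input: Optional[Any] = None
    previous_residual: Optional[Any] = None
```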

At a high level, this is what TeaCacheStrategy looks like. (In reality, the methods take different arguments.)

class TeaCacheStrategy:
	def __init__(self, params: TeaCacheParams): ...
	def maybe_reset(self, state: TeaCacheState): ...  # maybe reset the cache state
	def should_skip(self, state: TeaCacheState) -> bool: ...  # decide whether to skip computing the forward pass for this timestep
	def read(self, state: TeaCacheState): ...  # read from the cache
	def write(self, state: TeaCacheState): ...  # write to the cache

2. Integration Overhead

The Problem: Adding TeaCache to a new model required ~100 lines of boilerplate code (e.g., the large amount of code deleted in wanvideo.py).

The Solution: Rewrote the TeaCache API so models now only need to pass in the modulated input. All other logic is generic and handled by TeaCacheStrategy.

3. Code Duplication

The Problem: Handling the positive and negative CFG branches required duplicated code blocks in TeaCache to track their caching logic separately.

The Solution: Maintain a separate TeaCacheState for the positive and negative CFG branches.
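
One hypothetical way to keep the branches independent (the class and key names below are illustrative, not the PR's):

```python
class CFGCacheStates:
    """Holds one independent cache state per CFG branch.

    The same strategy code then runs for both branches; only the
    state object differs, so no logic needs to be duplicated.
    """

    def __init__(self, make_state):
        self.states = {"cond": make_state(), "uncond": make_state()}

    def get(self, is_positive: bool):
        # Select the state for the positive (cond) or negative (uncond) branch.
        return self.states["cond" if is_positive else "uncond"]
```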

4. Cannot Support Multiple Types of Caching

The Problem: Cacheable-DiT inherited directly from TeaCacheMixin. Since we only know which cache type we want during the forward pass, not at init, the model couldn't easily support or switch between caching strategies without changing the class it inherits from.

The Solution: Instead of Cacheable-DiT inheriting the cache, assign the cache to Cacheable-DiT.cache (composition). It is now initialized via init_cache(), allowing the model to dynamically support multiple caching types.
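
A hypothetical sketch of the composition pattern (the registry and the stub strategy classes below are simplified stand-ins for illustration):

```python
class TeaCacheStrategyStub:
    """Stand-in for the real TeaCacheStrategy."""
    name = "teacache"

class NoOpCacheStub:
    """Stand-in for a no-caching strategy."""
    name = "none"

class CacheableDiT:
    """Owns a cache object instead of inheriting caching behavior."""

    def __init__(self):
        self.cache = None  # no caching until explicitly initialized

    def init_cache(self, cache_type: str):
        # Strategies can be swapped at runtime; with inheritance this
        # would require changing the model's base class instead.
        registry = {"teacache": TeaCacheStrategyStub, "none": NoOpCacheStub}
        self.cache = registry[cache_type]()
```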

5. Not torch.compile Friendly

The Problem: TeaCache called np.poly1d and used Python if statements to decide when to skip a forward pass, both of which make the code hard to compile.

The Solution: Use torch.where instead of if statements and implement np.poly1d in torch.
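
A sketch of both replacements, assuming coefficients are ordered highest-degree first as in np.poly1d (the function names and threshold logic are illustrative, not the PR's exact code):

```python
import torch

def polyval(coeffs, x: torch.Tensor) -> torch.Tensor:
    # Horner's rule: a torch replacement for np.poly1d(coeffs)(x),
    # traceable by torch.compile since it is a fixed-length loop.
    result = torch.zeros_like(x)
    for c in coeffs:
        result = result * x + c
    return result

def accumulate_or_reset(accum: torch.Tensor, dist: torch.Tensor, thresh: float):
    # torch.where replaces the Python `if`, keeping the graph branch-free:
    # skip while the accumulated distance stays under the threshold, and
    # reset the accumulator to zero whenever we recompute.
    new_accum = accum + dist
    should_skip = new_accum < thresh
    new_accum = torch.where(should_skip, new_accum, torch.zeros_like(new_accum))
    return new_accum, should_skip
```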

Accuracy Tests

Speed Tests and Profiling

Checklist

Review and Merge Process

  1. Ping Merge Oncalls to start the process. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
  • Common commands include /tag-and-rerun-ci, /tag-run-ci-label, /rerun-failed-ci
  4. After green CI and required approvals, ask Merge Oncalls or people with Write permission to merge the PR.

@github-actions github-actions bot added the diffusion SGLang Diffusion label Mar 28, 2026

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request renames teacache_params to cache_params and introduces a calibrate_cache flag with an associated CLI argument. The review identifies that the renaming is incomplete, which will cause runtime errors in other modules like teacache.py, and suggests renaming enable_teacache for consistency. Furthermore, the calibrate_cache parameter appears to be dead code as it is not utilized within the current changes.

@eitanturok eitanturok changed the title Refactor Teacache [diffusion[ Refactor Teacache Mar 28, 2026
@eitanturok eitanturok changed the title [diffusion[ Refactor Teacache [diffusion] Refactor TeaCache Mar 28, 2026
@eitanturok eitanturok marked this pull request as ready for review March 28, 2026 23:18
@eitanturok
Contributor Author

@yhyang201 @mickqian can you please tag this with run-ci?

@eitanturok
Contributor Author

@yhyang201 @mickqian @BBuf @yingluosanqian @ping1jing2 can you please tag this with run-ci?

