[diffusion] Refactor TeaCache #21613
Open
eitanturok wants to merge 16 commits into sgl-project:main from
Conversation
Contributor
Code Review
This pull request renames teacache_params to cache_params and introduces a calibrate_cache flag with an associated CLI argument. The review identifies that the renaming is incomplete, which will cause runtime errors in other modules like teacache.py, and suggests renaming enable_teacache for consistency. Furthermore, the calibrate_cache parameter appears to be dead code as it is not utilized within the current changes.
Contributor
Author
@yhyang201 @mickqian can you please tag this with run-ci?
Contributor
Author
@yhyang201 @mickqian @BBuf @yingluosanqian @ping1jing2 can you please tag this with run-ci?
Motivation
This PR cleans up the TeaCache implementation.
This is a prerequisite for supporting different timestep caching methods like MagCache (#18498 #19957).
Modifications
Here are the changes in more detail:
1. Fragmented State

The Problem: Parameters were scattered across the model class, TeaCacheMixin, TeaCacheParams, and TeaCacheContext. The TeaCacheParams class confusingly had methods attached to it. These classes and parameters were all updated at different times, making it difficult to manage TeaCache across different requests.

The Solution: Split the logic into three distinct components: TeaCacheParams and TeaCacheState hold the configuration and per-request state, while TeaCacheStrategy uses them to decide when to skip a computation. At a high level, this is what TeaCacheStrategy looks like. (In reality, the methods take in different arguments.)

2. Integration Overhead
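To make the split in point 1 concrete, here is a rough sketch of the three components. The field names, method names, and the scalar-distance logic are illustrative assumptions, not the PR's actual code; the real implementation operates on tensors, and the PR notes the real methods take different arguments.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class TeaCacheParams:
    # Static configuration; field name is illustrative.
    rel_l1_threshold: float = 0.2

@dataclass
class TeaCacheState:
    # Mutable per-request state, kept separate from the config
    # (one instance can be created per CFG branch, as in point 3).
    accumulated_distance: float = 0.0
    previous_input: Optional[float] = None

class TeaCacheStrategy:
    """Decides when a model may skip a forward pass."""
    def __init__(self, params: TeaCacheParams):
        self.params = params
        self.state = TeaCacheState()

    def should_skip(self, modulated_input: float) -> bool:
        # Accumulate the change in the modulated input; skip the forward
        # pass while the accumulated change stays under the threshold.
        if self.state.previous_input is None:
            self.state.previous_input = modulated_input
            return False
        self.state.accumulated_distance += abs(
            modulated_input - self.state.previous_input)
        self.state.previous_input = modulated_input
        if self.state.accumulated_distance < self.params.rel_l1_threshold:
            return True
        self.state.accumulated_distance = 0.0  # recompute and reset
        return False
```

Because state lives in its own object rather than on the model or the params, resetting it between requests amounts to constructing a fresh TeaCacheState.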
The Problem: Adding TeaCache to a new model required ~100 lines of boilerplate code (e.g., the large amount of code deleted in wanvideo.py).

The Solution: Rewrote the TeaCache API so models now only need to pass in the modulated input. All other logic is generic and handled by TeaCacheStrategy.

3. Code Duplication
The Problem: Handling positive and negative CFG branches required duplicated code blocks in TeaCache to track separate caching logic.

The Solution: Keep a separate TeaCacheState for the positive and negative CFG branches.

4. Cannot support multiple types of caching
The Problem: Cacheable-DiT inherited directly from TeaCacheMixin. Since we only know which cache type we want during the forward pass, not at init, the model couldn't easily support or switch between different caching strategies, as we'd need to change the class we inherit from.

The Solution: Instead of Cacheable-DiT inheriting the cache, assign the cache to Cacheable-DiT.cache (composition). It is now initialized via init_cache(), allowing the model to dynamically support multiple caching types.

5. Not torch.compile friendly
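The composition approach in point 4 might look like the following sketch. The class name, method names, and strategy interface here are simplified stand-ins, not the PR's exact API.

```python
class CacheableDiT:
    """Sketch of composition over inheritance: the model owns a `cache`
    attribute instead of inheriting from TeaCacheMixin."""

    def __init__(self):
        self.cache = None  # no caching strategy chosen at construction time

    def init_cache(self, strategy):
        # Chosen at request time (e.g. TeaCache vs. MagCache), not baked
        # into the class hierarchy.
        self.cache = strategy

    def forward(self, x):
        # Generic skip logic; the strategy decides, the model just asks.
        if self.cache is not None and self.cache.should_skip(x):
            return self.cache.cached_output(x)
        return self._full_forward(x)

    def _full_forward(self, x):
        return x  # placeholder for the real transformer forward pass
```

Switching caching methods then only requires calling init_cache() with a different strategy object, with no change to the inheritance chain.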
The Problem: TeaCache called np.poly1d and used Python if statements to decide when to skip a forward pass, which makes this code hard to compile.

The Solution: Use torch.where instead of if, and implement np.poly1d in torch.

Accuracy Tests
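The torch.compile-friendly rewrite described in point 5 could be sketched as follows. The function signatures, the relative-L1 distance input, and the accumulator-reset logic are assumptions for illustration, not the PR's exact code.

```python
import torch

def poly1d(coefficients: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Evaluate a polynomial (highest-degree coefficient first) with
    Horner's rule, mirroring np.poly1d but staying in torch so the
    operation can be traced and compiled."""
    result = torch.zeros_like(x)
    for c in coefficients:
        result = result * x + c
    return result

def should_skip(rel_l1_distance: torch.Tensor,
                accumulated: torch.Tensor,
                coefficients: torch.Tensor,
                threshold: float):
    """Branch-free skip decision: rescale the raw distance with the
    fitted polynomial, accumulate it, and compare against the threshold
    using torch.where instead of a Python `if`."""
    accumulated = accumulated + poly1d(coefficients, rel_l1_distance)
    skip = accumulated < threshold
    # Reset the accumulator whenever we decide to recompute.
    accumulated = torch.where(skip, accumulated, torch.zeros_like(accumulated))
    return skip, accumulated
```

Because both the polynomial evaluation and the decision are expressed as tensor ops, the whole path stays inside the compiled graph instead of falling back to Python control flow.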
Speed Tests and Profiling
Checklist
Review and Merge Process
/tag-and-rerun-ci, /tag-run-ci-label, /rerun-failed-ci