[Bagel] Support Cache-Dit #736
Signed-off-by: princepride <wangzhipeng628@gmail.com>
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 582c8b1fd3
    prefix=(
        f"{self.cache_prefix}_{block_id}_Bn_residual"
        if self.context_manager.is_cache_residual()
        else f"{self.cache_prefix}_Bn_hidden_states"
Fix prune cache key mismatch for non-residual mode
When pruning is used with non-residual caching, apply_prune looks up the Bn buffer under f"{self.cache_prefix}_Bn_hidden_states", but the buffer is stored later with f"{self.cache_prefix}_{block_id}_Bn_hidden_states". This mismatch means the prune path will fail to find the cached hidden states (likely raising an assertion or reusing the wrong data) whenever is_cache_residual() is false and _maybe_prune() returns true (e.g., SCM/pruning enabled). Aligning the prefix with the one used in set_Bn_buffer avoids the incorrect cache lookup.
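The mismatch described above can be reproduced in isolation. The following is a minimal sketch, not cache-dit code: `BnBufferStore`, `store_key`, and `buggy_lookup_key` are hypothetical names, and a plain dict stands in for the real buffer storage. It only shows that a lookup key without `block_id` cannot find a value stored under a key that includes it.

```python
class BnBufferStore:
    """Hypothetical stand-in for the cache context's Bn buffer storage."""

    def __init__(self, cache_prefix):
        self.cache_prefix = cache_prefix
        self._buffers = {}

    def store_key(self, block_id, residual):
        # Key format used when the buffer is stored (block_id included),
        # as described in the review comment. The residual variant is assumed.
        suffix = "Bn_residual" if residual else "Bn_hidden_states"
        return f"{self.cache_prefix}_{block_id}_{suffix}"

    def buggy_lookup_key(self, block_id, residual):
        # Key format used by the prune path in the reviewed snippet:
        # the non-residual branch omits block_id.
        if residual:
            return f"{self.cache_prefix}_{block_id}_Bn_residual"
        return f"{self.cache_prefix}_Bn_hidden_states"

    def set_buffer(self, block_id, residual, value):
        self._buffers[self.store_key(block_id, residual)] = value

    def get_buffer(self, key):
        return self._buffers.get(key)


store = BnBufferStore("bagel")
store.set_buffer(block_id=3, residual=False, value="hidden-states")

# Lookup with the mismatched key finds nothing; the aligned key succeeds.
miss = store.get_buffer(store.buggy_lookup_key(3, residual=False))
hit = store.get_buffer(store.store_key(3, residual=False))
```

In this sketch the residual branch is unaffected because both key builders include `block_id` there; only the non-residual branch diverges.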
I copied this directly from cache-dit.
    return refresh_cache_context

    class BagelCachedContextManager(CachedContextManager):
Do you think it's feasible to add it directly in cache-dit, so that cache-dit repo can support bagel natively as well? @DefTruth
That sounds good to me. Feel free to send a PR to cache-dit, and I can help review it.
Bagel's original code is extremely complex, and it doesn't name its modules according to the standard DiT model conventions. I can only guarantee that the code I submit will support Bagel's acceleration on vLLM-Omni.
BTW, I believe hunyuan image will also benefit from this new feature.
@SamitHuang @hsliuustc0106 What do you think?
Indeed, diffusers doesn't support Bagel or HYI3.0. CacheDiT doesn't plan to support image/video generation models in transformers, right? @DefTruth
Since cache-dit is built on top of the diffusers library, we have long focused mainly on the models supported in diffusers.
so the conclusion is still keeping the change in vllm-omni here?
> so the conclusion is still keeping the change in vllm-omni here?
I'm afraid so.
Please update the diffusion acceleration md to cover this support.
Signed-off-by: princepride <wangzhipeng628@gmail.com>
@hsliuustc0106 The docs have been updated.
@tjtanaa The AMD CI failed, PTAL.
Signed-off-by: princepride <wangzhipeng628@gmail.com> Signed-off-by: Chen Yang <2082464740@qq.com>
Signed-off-by: princepride <wangzhipeng628@gmail.com>
Purpose
This PR enables `cache-dit` acceleration for the Bagel model in `vllm-omni`.

The integration presents a unique challenge: Bagel passes a `NaiveCache` object via `past_key_values` for its KV-cache mechanism, whereas `cache-dit`'s standard `ForwardPattern` expects an `encoder_hidden_states` tensor (typically used for cross-attention in DiTs) and attempts to perform tensor operations (like residual calculation) on it.

To resolve this without modifying the original `cache-dit` library or the Bagel model definition, this PR implements a set of custom adapter classes that inherit from and extend `cache-dit`'s core components:

- `BagelCachedContextManager`: Inherits from `CachedContextManager`. It overrides `apply_cache` to add `isinstance(..., torch.Tensor)` checks. This ensures that when `NaiveCache` objects are passed through the pipeline (masquerading as `encoder_hidden_states` in the pattern), they are preserved as-is without crashing due to invalid tensor operations (like `.contiguous()` or subtraction).
- `BagelCachedBlocks`: Inherits from `CachedBlocks_Pattern_0_1_2`. It bridges `cache-dit`'s expected `encoder_hidden_states` and Bagel's `past_key_values`: the `forward` method intercepts `past_key_values` from kwargs, and the internal block call methods (`call_Fn/Mn/Bn_blocks`) ensure the underlying Bagel decoder layer receives arguments in the correct keyword format, preventing `TypeError`. It also overrides `call_Mn_blocks` and `compute_or_prune` to skip residual calculation for non-tensor `past_key_values`, protecting the `NaiveCache`.
- `BagelCachedAdapter`: Inherits from `CachedAdapter`. It serves as the entry point, orchestrating the instantiation of the custom `BagelCachedContextManager` and `BagelCachedBlocks` instead of the default ones.

These adapters are defined in `vllm_omni/diffusion/cache/cache_dit_backend.py` and applied specifically for Bagel pipelines.

Test Plan
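Separately from the end-to-end command, the type guard described under Purpose can be checked in isolation. The following is a minimal, hypothetical sketch and not the code in this PR: `FakeTensor` and `FakeNaiveCache` are stand-ins for `torch.Tensor` and Bagel's `NaiveCache` so that it runs without PyTorch, and `residual_or_passthrough` is an illustrative name for the "compute a residual only for tensors" rule.

```python
class FakeTensor:
    """Stand-in for torch.Tensor (assumption: only subtraction is needed)."""

    def __init__(self, value):
        self.value = value

    def __sub__(self, other):
        return FakeTensor(self.value - other.value)


class FakeNaiveCache:
    """Stand-in for Bagel's NaiveCache; deliberately supports no arithmetic."""


def residual_or_passthrough(new, old):
    """Return new - old for tensors; return non-tensor objects unchanged."""
    if isinstance(new, FakeTensor) and isinstance(old, FakeTensor):
        return new - old
    # Non-tensor inputs (for example a KV-cache object) are preserved as-is,
    # so no tensor operation is ever attempted on them.
    return new


tensor_residual = residual_or_passthrough(FakeTensor(5.0), FakeTensor(2.0))
kv_cache = FakeNaiveCache()
preserved = residual_or_passthrough(kv_cache, kv_cache)
```

Without the `isinstance` check, the second call would raise a `TypeError` because `FakeNaiveCache` defines no subtraction, which mirrors the failure mode the adapters are meant to avoid.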
Test Command:
[Optional] You can directly use `examples/offline_inference/text_to_image/text_to_image.py`.

Test Result
Performance Comparison:
Visual Validation:
Baseline (No Cache)
Cache-Dit Enabled