Skip to content

[Bagel] Support Cache-Dit#736

Merged
hsliuustc0106 merged 3 commits into
vllm-project:mainfrom
princepride:bagel-model-cache-dit
Jan 17, 2026
Merged

[Bagel] Support Cache-Dit#736
hsliuustc0106 merged 3 commits into
vllm-project:mainfrom
princepride:bagel-model-cache-dit

Conversation

@princepride
Copy link
Copy Markdown
Collaborator

Purpose

This PR enables cache-dit acceleration for the Bagel model in vllm-omni.

The integration presents a unique challenge: Bagel leverages a NaiveCache object passed via past_key_values for its KV-cache mechanism, whereas cache-dit's standard ForwardPattern expects an encoder_hidden_states tensor (typically used for cross-attention in DiTs) and attempts to perform tensor operations (like residual calculation) on it.

To resolve this without modifying the original cache-dit library or the Bagel model definition, this PR implements a set of custom adapter classes that inherit from and extend cache-dit's core components:

  1. BagelCachedContextManager: Inherits from CachedContextManager. It overrides apply_cache to add isinstance(..., torch.Tensor) checks. This ensures that when NaiveCache objects are passed through the pipeline (masquerading as encoder_hidden_states in the pattern), they are preserved as-is without crashing due to invalid tensor operations (like .contiguous() or subtraction).

  2. BagelCachedBlocks: Inherits from CachedBlocks_Pattern_0_1_2.

    • Signature Mapping: It handles the argument mapping between cache-dit's expected encoder_hidden_states and Bagel's past_key_values. The forward method intercepts past_key_values from kwargs, and internal block call methods (call_Fn/Mn/Bn_blocks) ensure the underlying Bagel decoder layer receives arguments in the correct keyword format preventing TypeError.
    • Safe Residuals: Overrides call_Mn_blocks and compute_or_prune to skip residual calculation for non-tensor past_key_values, protecting the NaiveCache.
  3. BagelCachedAdapter: Inherits from CachedAdapter. It serves as the entry point, orchestrating the instantiation of the custom BagelCachedContextManager and BagelCachedBlocks instead of the default ones.

These adapters are defined in vllm_omni/diffusion/cache/cache_dit_backend.py and applied specifically for Bagel pipelines.

Test Plan

Test Command:

import torch
import time
from vllm_omni.entrypoints.omni_diffusion import OmniDiffusion

def run_test(cache_backend=None, cache_config=None, run_name="Baseline"):
    pipeline = OmniDiffusion(
        model="../models/BAGEL-7B-MoT",
        cache_backend=cache_backend,
        cache_config=cache_config
    )
    prompt = "A woman wear a blue dress"
    start_gen = time.time()
    result = pipeline.generate(
        prompt, 
        seed=52
    )
    end_gen = time.time()
    gen_time = end_gen - start_gen
    print(f"Generation Time: {gen_time:.4f}s")
    
    output_path = f"output_bagel_{run_name.lower().replace(' ', '_')}.png"
    result.images[0].save(output_path)

def main():
    # 1. Baseline (No Cache)
    baseline_time = run_test(cache_backend=None, cache_config=None, run_name="Baseline (No Cache)")
    
    cache_config = {
        # DBCache parameters [cache-dit only]
        "Fn_compute_blocks": 1,  # Optimized for single-transformer models
        "Bn_compute_blocks": 0,  # Number of backward compute blocks
        "max_warmup_steps": 4,  # Maximum warmup steps (works for few-step models)
        "residual_diff_threshold": 0.24,  # Higher threshold for more aggressive caching
        "max_continuous_cached_steps": 3,  # Limit to prevent precision degradation
        # TaylorSeer parameters [cache-dit only]
        "enable_taylorseer": False,  # Disabled by default (not suitable for few-step models)
        "taylorseer_order": 1,  # TaylorSeer polynomial order
        # SCM (Step Computation Masking) parameters [cache-dit only]
        "scm_steps_mask_policy": None,  # SCM mask policy: None (disabled), "slow", "medium", "fast", "ultra"
        "scm_steps_policy": "dynamic",  # SCM steps policy: "dynamic" or "static"
    }
    run_test(cache_backend="cache_dit", cache_config=cache_config, run_name="Cache-Dit")

if __name__ == "__main__":
    main()

[Optional] You can directly use examples/offline_inference/text_to_image/text_to_image.py

python examples/offline_inference/text_to_image/text_to_image.py --model ../models/BAGEL-7B-MoT --cache_backend cache_dit

Test Result

Performance Comparison:

Metric Baseline (No Cache) Cache-Dit Enabled Speedup
Generation Time 9.28s 4.41s ~2.1x

Visual Validation:


Baseline (No Cache)

Cache-Dit Enabled

Signed-off-by: princepride <wangzhipeng628@gmail.com>
Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 582c8b1fd3

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment on lines +540 to +543
prefix=(
f"{self.cache_prefix}_{block_id}_Bn_residual"
if self.context_manager.is_cache_residual()
else f"{self.cache_prefix}_Bn_hidden_states"
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Fix prune cache key mismatch for non-residual mode

When pruning is used with non-residual caching, apply_prune looks up the Bn buffer under f"{self.cache_prefix}_Bn_hidden_states", but the buffer is stored later with f"{self.cache_prefix}_{block_id}_Bn_hidden_states". This mismatch means the prune path will fail to find the cached hidden states (likely raising an assertion or reusing the wrong data) whenever is_cache_residual() is false and _maybe_prune() returns true (e.g., SCM/pruning enabled). Aligning the prefix with the one used in set_Bn_buffer avoids the incorrect cache lookup.

Useful? React with 👍 / 👎.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I directly copied from cache-dit

return refresh_cache_context


class BagelCachedContextManager(CachedContextManager):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you think it's feasible to add it directly in cache-dit, so that cache-dit repo can support bagel natively as well? @DefTruth

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That sounds good to me. Feel free to send a PR to cache-dit, and I can help review it.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bagel's original code is extremely complex, and it doesn't name its modules according to the standard Dit model. I can only guarantee that the code I submit will support Bagel's acceleration on vLLM-Omni.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BTW, I believe hunyuan image will also benefit from this new feature.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@SamitHuang @hsliuustc0106 What do you think?

Copy link
Copy Markdown
Collaborator

@SamitHuang SamitHuang Jan 15, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

indeed, diffusers doesn't support bagel or HYI3.0. CacheDiT doesn't plan to support image/video generation models in transformers, right? @DefTruth

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

since cache-dit is built on top of the diffusers library, we have long focused mainly on the models supported in diffusers.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so the conclusion is still keeping the change in vllm-omni here?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so the conclusion is still keeping the change in vllm-omni here?

I'm afraid so.

@hsliuustc0106
Copy link
Copy Markdown
Collaborator

please update the diffusion acceleration md for this support

Signed-off-by: princepride <wangzhipeng628@gmail.com>
@princepride
Copy link
Copy Markdown
Collaborator Author

@hsliuustc0106 Already update docs.

@princepride
Copy link
Copy Markdown
Collaborator Author

@tjtanaa amd CI failed, PTAL.

Copy link
Copy Markdown
Collaborator

@hsliuustc0106 hsliuustc0106 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm

@hsliuustc0106 hsliuustc0106 merged commit 0888520 into vllm-project:main Jan 17, 2026
7 checks passed
erfgss pushed a commit to erfgss/vllm-omni that referenced this pull request Jan 19, 2026
Signed-off-by: princepride <wangzhipeng628@gmail.com>
Signed-off-by: Chen Yang <2082464740@qq.com>
with1015 pushed a commit to with1015/vllm-omni that referenced this pull request Jan 20, 2026
Signed-off-by: princepride <wangzhipeng628@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ready label to trigger buildkite CI

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants