From 582c8b1fd3d74c96cbfd03c9a3879dfb980d01e9 Mon Sep 17 00:00:00 2001 From: princepride Date: Sun, 11 Jan 2026 08:36:18 +0000 Subject: [PATCH 1/2] bagel support cache-dit Signed-off-by: princepride --- .../diffusion/cache/cache_dit_backend.py | 387 ++++++++++++++++++ .../models/bagel/bagel_transformer.py | 15 +- 2 files changed, 397 insertions(+), 5 deletions(-) diff --git a/vllm_omni/diffusion/cache/cache_dit_backend.py b/vllm_omni/diffusion/cache/cache_dit_backend.py index 3485e2262c9..0c43659c9a6 100644 --- a/vllm_omni/diffusion/cache/cache_dit_backend.py +++ b/vllm_omni/diffusion/cache/cache_dit_backend.py @@ -7,11 +7,19 @@ pipelines in vllm-omni, supporting both single and dual-transformer architectures. """ +import functools from collections.abc import Callable +from contextlib import ExitStack from typing import Any, Optional import cache_dit +import torch from cache_dit import BlockAdapter, DBCacheConfig, ForwardPattern, ParamsModifier, TaylorSeerCalibratorConfig +from cache_dit.caching.block_adapters import FakeDiffusionPipeline +from cache_dit.caching.cache_adapters.cache_adapter import CachedAdapter +from cache_dit.caching.cache_blocks.pattern_0_1_2 import CachedBlocks_Pattern_0_1_2 +from cache_dit.caching.cache_contexts import BasicCacheConfig +from cache_dit.caching.cache_contexts.cache_manager import CachedContextManager from vllm.logger import init_logger from vllm_omni.diffusion.cache.base import CacheBackend @@ -401,6 +409,384 @@ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool return refresh_cache_context +class BagelCachedContextManager(CachedContextManager): + """ + Custom CachedContextManager for Bagel that safely handles NaiveCache objects + (mapped to encoder_hidden_states) by skipping tensor operations on them. + """ + + @torch.compiler.disable + def apply_cache( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + prefix: str = "Bn", + encoder_prefix: str = "Bn_encoder", + ) -> tuple[torch.Tensor, torch.Tensor | None]: + # Allow Bn and Fn prefix to be used for residual cache. + if "Bn" in prefix: + hidden_states_prev = self.get_Bn_buffer(prefix) + else: + hidden_states_prev = self.get_Fn_buffer(prefix) + + assert hidden_states_prev is not None, f"{prefix}_buffer must be set before" + + if self.is_cache_residual(): + hidden_states = hidden_states_prev + hidden_states + else: + # If cache is not residual, we use the hidden states directly + hidden_states = hidden_states_prev + + hidden_states = hidden_states.contiguous() + + if encoder_hidden_states is not None: + if "Bn" in encoder_prefix: + encoder_hidden_states_prev = self.get_Bn_encoder_buffer(encoder_prefix) + else: + encoder_hidden_states_prev = self.get_Fn_encoder_buffer(encoder_prefix) + + if encoder_hidden_states_prev is not None: + if self.is_encoder_cache_residual(): + # FIX: Check if encoder_hidden_states is a tensor before adding + if isinstance(encoder_hidden_states, torch.Tensor) and isinstance( + encoder_hidden_states_prev, torch.Tensor + ): + encoder_hidden_states = encoder_hidden_states_prev + encoder_hidden_states + else: + # If encoder cache is not residual, we use the encoder hidden states directly + encoder_hidden_states = encoder_hidden_states_prev + + # FIX: Check if encoder_hidden_states is a tensor before calling contiguous + if isinstance(encoder_hidden_states, torch.Tensor): + encoder_hidden_states = encoder_hidden_states.contiguous() + + return hidden_states, encoder_hidden_states + + +class BagelCachedBlocks(CachedBlocks_Pattern_0_1_2): + """ + Custom CachedBlocks for Bagel that safely handles NaiveCache objects + by adding isinstance checks in call_Mn_blocks and compute_or_prune. + """ + + def call_Mn_blocks( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + *args, + **kwargs, + ): + original_hidden_states = hidden_states + original_encoder_hidden_states = encoder_hidden_states + for block in self._Mn_blocks(): + hidden_states = block( + hidden_states, + encoder_hidden_states, + *args, + **kwargs, + ) + hidden_states, encoder_hidden_states = self._process_block_outputs(hidden_states, encoder_hidden_states) + + # compute hidden_states residual + hidden_states = hidden_states.contiguous() + + hidden_states_residual = hidden_states - original_hidden_states + + if ( + encoder_hidden_states is not None + and original_encoder_hidden_states is not None + and isinstance(encoder_hidden_states, torch.Tensor) # FIX: Added Check + ): + encoder_hidden_states = encoder_hidden_states.contiguous() + encoder_hidden_states_residual = encoder_hidden_states - original_encoder_hidden_states + else: + encoder_hidden_states_residual = None + + return ( + hidden_states, + encoder_hidden_states, + hidden_states_residual, + encoder_hidden_states_residual, + ) + + def compute_or_prune( + self, + block_id: int, # Block index in the transformer blocks + # Below are the inputs to the block + block, # The transformer block to be executed + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + *args, + **kwargs, + ): + # NOTE: Although Bagel likely won't use pruning, implementing safe version just in case. + # Copy-pasted from original but adding checks. + + original_hidden_states = hidden_states + original_encoder_hidden_states = encoder_hidden_states + + can_use_prune = self._maybe_prune( + block_id, + hidden_states, + prefix=f"{self.cache_prefix}_{block_id}_Fn_original", + ) + + torch._dynamo.graph_break() + if can_use_prune: + self.context_manager.add_pruned_step() + hidden_states, encoder_hidden_states = self.context_manager.apply_prune( + hidden_states, + encoder_hidden_states, + prefix=( + f"{self.cache_prefix}_{block_id}_Bn_residual" + if self.context_manager.is_cache_residual() + else f"{self.cache_prefix}_Bn_hidden_states" + ), + encoder_prefix=( + f"{self.cache_prefix}_{block_id}_Bn_encoder_residual" + if self.context_manager.is_encoder_cache_residual() + else f"{self.cache_prefix}_{block_id}_Bn_encoder_hidden_states" + ), + ) + torch._dynamo.graph_break() + else: + # Normal steps: Compute the block and cache the residuals. + hidden_states = block( + hidden_states, + encoder_hidden_states, + *args, + **kwargs, + ) + hidden_states, encoder_hidden_states = self._process_block_outputs(hidden_states, encoder_hidden_states) + if not self._skip_prune(block_id): + hidden_states = hidden_states.contiguous() + hidden_states_residual = hidden_states - original_hidden_states + + if ( + encoder_hidden_states is not None + and original_encoder_hidden_states is not None + and isinstance(encoder_hidden_states, torch.Tensor) # FIX: Added Check + ): + encoder_hidden_states = encoder_hidden_states.contiguous() + encoder_hidden_states_residual = encoder_hidden_states - original_encoder_hidden_states + else: + encoder_hidden_states_residual = None + + self.context_manager.set_Fn_buffer( + original_hidden_states, + prefix=f"{self.cache_prefix}_{block_id}_Fn_original", + ) + if self.context_manager.is_cache_residual(): + self.context_manager.set_Bn_buffer( + hidden_states_residual, + prefix=f"{self.cache_prefix}_{block_id}_Bn_residual", + ) + else: + self.context_manager.set_Bn_buffer( + hidden_states, + prefix=f"{self.cache_prefix}_{block_id}_Bn_hidden_states", + ) + if encoder_hidden_states_residual is not None: + if self.context_manager.is_encoder_cache_residual(): + self.context_manager.set_Bn_encoder_buffer( + encoder_hidden_states_residual, + prefix=f"{self.cache_prefix}_{block_id}_Bn_encoder_residual", + ) + else: + self.context_manager.set_Bn_encoder_buffer( + encoder_hidden_states_residual, + prefix=f"{self.cache_prefix}_{block_id}_Bn_encoder_hidden_states", + ) + torch._dynamo.graph_break() + + return hidden_states, encoder_hidden_states + + +class BagelCachedAdapter(CachedAdapter): + """ + Custom CachedAdapter for Bagel that uses BagelCachedContextManager and BagelCachedBlocks. + """ + + @classmethod + def create_context( + cls, + block_adapter: BlockAdapter, + **context_kwargs, + ) -> tuple[list[str], list[dict[str, Any]]]: + # Override to use BagelCachedContextManager + + BlockAdapter.assert_normalized(block_adapter) + + if BlockAdapter.is_cached(block_adapter.pipe): + return block_adapter.pipe + + # Check context_kwargs + context_kwargs = cls.check_context_kwargs(block_adapter, **context_kwargs) + + # Each Pipeline should have it's own context manager instance. + cache_config: BasicCacheConfig = context_kwargs.get("cache_config", None) + assert cache_config is not None, "cache_config can not be None." + + # Apply cache on pipeline: wrap cache context + pipe_cls_name = block_adapter.pipe.__class__.__name__ + + # USE CUSTOM CONTEXT MANAGER + context_manager = BagelCachedContextManager( + name=f"{pipe_cls_name}_{hash(id(block_adapter.pipe))}", + persistent_context=isinstance(block_adapter.pipe, FakeDiffusionPipeline), + ) + + flatten_contexts, contexts_kwargs = cls.modify_context_params(block_adapter, **context_kwargs) + + block_adapter.pipe._context_manager = context_manager # instance level + + if not context_manager.persistent_context: + original_call = block_adapter.pipe.__class__.__call__ + + @functools.wraps(original_call) + def new_call(self, *args, **kwargs): + with ExitStack() as stack: + # cache context will be reset for each pipe inference + for context_name, context_kwargs in zip(flatten_contexts, contexts_kwargs): + stack.enter_context( + context_manager.enter_context( + context_manager.reset_context( + context_name, + **context_kwargs, + ), + ) + ) + outputs = original_call(self, *args, **kwargs) + cls.apply_stats_hooks(block_adapter) + return outputs + + block_adapter.pipe.__class__.__call__ = new_call + block_adapter.pipe.__class__._original_call = original_call + + else: + # Init persistent cache context for transformer + for context_name, context_kwargs in zip(flatten_contexts, contexts_kwargs): + context_manager.reset_context( + context_name, + **context_kwargs, + ) + + block_adapter.pipe.__class__._is_cached = True + + cls.apply_params_hooks(block_adapter, contexts_kwargs) + + return flatten_contexts, contexts_kwargs + + @classmethod + def collect_unified_blocks( + cls, + block_adapter: BlockAdapter, + contexts_kwargs: list[dict], + ) -> list[dict[str, torch.nn.ModuleList]]: + # Override to use BagelCachedBlocks + + BlockAdapter.assert_normalized(block_adapter) + + total_cached_blocks: list[dict[str, torch.nn.ModuleList]] = [] + assert hasattr(block_adapter.pipe, "_context_manager") + # Skipping isinstance check for ContextManager._supported_managers to avoid import issues + + for i in range(len(block_adapter.transformer)): + unified_blocks_bind_context = {} + for j in range(len(block_adapter.blocks[i])): + cache_config: BasicCacheConfig = contexts_kwargs[i * len(block_adapter.blocks[i]) + j]["cache_config"] + + # Directly instantiate BagelCachedBlocks + unified_blocks_bind_context[block_adapter.unique_blocks_name[i][j]] = torch.nn.ModuleList( + [ + BagelCachedBlocks( + # 0. Transformer blocks configuration + block_adapter.blocks[i][j], + transformer=block_adapter.transformer[i], + forward_pattern=block_adapter.forward_pattern[i][j], + check_forward_pattern=block_adapter.check_forward_pattern, + check_num_outputs=block_adapter.check_num_outputs, + # 1. Cache/Prune context configuration + cache_prefix=block_adapter.blocks_name[i][j], + cache_context=block_adapter.unique_blocks_name[i][j], + context_manager=block_adapter.pipe._context_manager, + cache_type=cache_config.cache_type, + ) + ] + ) + + total_cached_blocks.append(unified_blocks_bind_context) + + return total_cached_blocks + + +def enable_cache_for_bagel(pipeline: Any, cache_config: Any) -> Callable[[int], None]: + """Enable cache-dit for Bagel model (via OmniDiffusion pipeline). + + Args: + pipeline: The OmniDiffusion pipeline instance. + cache_config: DiffusionCacheConfig instance with cache configuration. + + Returns: + A refresh function that can be called to update cache context with new num_inference_steps. + """ + # Build DBCacheConfig + db_cache_config = _build_db_cache_config(cache_config) + + # Build calibrator config if TaylorSeer is enabled + calibrator_config = None + if cache_config.enable_taylorseer: + taylorseer_order = cache_config.taylorseer_order + calibrator_config = TaylorSeerCalibratorConfig(taylorseer_order=taylorseer_order) + logger.info(f"TaylorSeer enabled with order={taylorseer_order}") + + # Access the transformer: BagelPipeline -> Qwen2MoTForCausalLM -> Qwen2MoTModel + # BagelPipeline has self.language_model which is Qwen2MoTForCausalLM + # Qwen2MoTForCausalLM has self.model which is Qwen2MoTModel + transformer = pipeline.language_model.model + + logger.info( + f"Enabling cache-dit on Bagel transformer: " + f"Fn={db_cache_config.Fn_compute_blocks}, " + f"Bn={db_cache_config.Bn_compute_blocks}, " + f"W={db_cache_config.max_warmup_steps}, " + ) + + # Enable cache-dit on the transformer + # Pattern_0 corresponds to (hidden_states, encoder_hidden_states) input, output + # Custom adapter for Bagel to handle NaiveCache correctly + # from vllm_omni.diffusion.cache.bagel_cache_adapter import BagelCachedAdapter # No longer needed + BagelCachedAdapter.apply( + BlockAdapter( + transformer=transformer, + blocks=transformer.layers, + forward_pattern=ForwardPattern.Pattern_0, + ), + cache_config=db_cache_config, + calibrator_config=calibrator_config, + ) + + def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool = True) -> None: + transformer = pipeline.language_model.model + if cache_config.scm_steps_mask_policy is None: + cache_dit.refresh_context(transformer, num_inference_steps=num_inference_steps, verbose=verbose) + else: + cache_dit.refresh_context( + transformer, + cache_config=DBCacheConfig().reset( + num_inference_steps=num_inference_steps, + steps_computation_mask=cache_dit.steps_mask( + mask_policy=cache_config.scm_steps_mask_policy, + total_steps=num_inference_steps, + ), + steps_computation_policy=cache_config.scm_steps_policy, + ), + verbose=verbose, + ) + + return refresh_cache_context + + # Register custom cache-dit enablers after function definitions CUSTOM_DIT_ENABLERS.update( { @@ -409,6 +795,7 @@ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool "LongCatImagePipeline": enable_cache_for_longcat_image, "LongCatImageEditPipeline": enable_cache_for_longcat_image, "StableDiffusion3Pipeline": enable_cache_for_sd3, + "BagelPipeline": enable_cache_for_bagel, } ) diff --git a/vllm_omni/diffusion/models/bagel/bagel_transformer.py b/vllm_omni/diffusion/models/bagel/bagel_transformer.py index 7950389ef1f..256f25e0839 100644 --- a/vllm_omni/diffusion/models/bagel/bagel_transformer.py +++ b/vllm_omni/diffusion/models/bagel/bagel_transformer.py @@ -314,10 +314,12 @@ def __init__( def forward( self, - packed_query_sequence: torch.Tensor, - query_lens: torch.Tensor, - packed_query_position_embeddings: torch.Tensor, - packed_query_indexes: torch.Tensor, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | None = None, + packed_query_sequence: torch.Tensor | None = None, + query_lens: torch.Tensor = None, + packed_query_position_embeddings: torch.Tensor = None, + packed_query_indexes: torch.Tensor = None, past_key_values: NaiveCache | None = None, key_values_lens: torch.Tensor | None = None, packed_key_value_indexes: torch.Tensor | None = None, @@ -327,6 +329,8 @@ def forward( packed_vae_token_indexes=None, packed_text_indexes=None, ) -> BaseNavitOutputWithPast: + if packed_query_sequence is None: + packed_query_sequence = hidden_states residual = packed_query_sequence if mode == "und": packed_query_sequence = self.input_layernorm(packed_query_sequence) @@ -437,7 +441,8 @@ def forward( for layer_idx, decoder_layer in enumerate(self.layers): packed_query_sequence, past_key_values = decoder_layer( - packed_query_sequence=packed_query_sequence, + hidden_states=packed_query_sequence, + encoder_hidden_states=None, query_lens=query_lens, packed_query_position_embeddings=packed_query_position_embeddings, packed_query_indexes=packed_query_indexes, From d279dbd51a203f156250087a91cd473560f7143c Mon Sep 17 00:00:00 2001 From: princepride Date: Fri, 16 Jan 2026 12:29:44 +0000 Subject: [PATCH 2/2] update docs Signed-off-by: princepride --- docs/user_guide/diffusion_acceleration.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/user_guide/diffusion_acceleration.md b/docs/user_guide/diffusion_acceleration.md index 8f78ae32e50..df34af4e558 100644 --- a/docs/user_guide/diffusion_acceleration.md +++ b/docs/user_guide/diffusion_acceleration.md @@ -49,6 +49,7 @@ The following table shows which models are currently supported by each accelerat | **Qwen-Image-Layered** | `Qwen/Qwen-Image-Layered` | ❌ | ✅ | ✅ | ✅ | ✅ | | **Z-Image** | `Tongyi-MAI/Z-Image-Turbo` | ❌ | ✅ |❌ | ❌ | ❌ | | **Stable-Diffusion3.5** | `stabilityai/stable-diffusion-3.5` | ❌ | ✅ | ❌ | ❌ | ❌ | +| **Bagel** | `ByteDance-Seed/BAGEL-7B-MoT` | ❌ | ✅ | ❌ | ❌ | ❌ | ### VideoGen