diff --git a/vllm_omni/diffusion/cache/cache_dit_backend.py b/vllm_omni/diffusion/cache/cache_dit_backend.py
index 3daf883e0d..d5397dd166 100644
--- a/vllm_omni/diffusion/cache/cache_dit_backend.py
+++ b/vllm_omni/diffusion/cache/cache_dit_backend.py
@@ -1170,41 +1170,85 @@ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool
     return refresh_cache_context


-def enable_cache_for_glm_image(pipeline: Any, cache_config: Any) -> Callable[[int], None]:
-    """Enable cache-dit for GLM-Image pipeline.
+def enable_cache_for_flux2(pipeline: Any, cache_config: Any) -> Callable[[int], None]:
+    """Enable cache-dit for Flux.2-dev pipeline.

-    GLM-Image processes prompt and image by calling the transformer before the
-    denoising loop. When an input image is provided (editing mode), the cache must
-    be force-refreshed after the preprocessing step so stale hidden states are
-    discarded. Set force_refresh_step_hint = 1 for editing, None for text-to-image.
+    Args:
+        pipeline: The Flux2 pipeline instance.
+        cache_config: DiffusionCacheConfig instance with cache configuration.
+    Returns:
+        A refresh function that can be called with a new ``num_inference_steps``
+        to update the cache context for the pipeline.
     """
+    # Build DBCacheConfig for transformer
     db_cache_config = _build_db_cache_config(cache_config)

-    calibrator_config = None
+    calibrator = None
     if cache_config.enable_taylorseer:
-        calibrator_config = TaylorSeerCalibratorConfig(taylorseer_order=cache_config.taylorseer_order)
-        logger.info(f"TaylorSeer enabled with order={cache_config.taylorseer_order}")
+        taylorseer_order = cache_config.taylorseer_order
+        calibrator = TaylorSeerCalibratorConfig(taylorseer_order=taylorseer_order)
+        logger.info(f"TaylorSeer enabled with order={taylorseer_order}")
+
+    # Build ParamsModifier for transformer
+    modifier = ParamsModifier(
+        cache_config=db_cache_config,
+        calibrator_config=calibrator,
+    )

     logger.info(
-        f"Enabling cache-dit on GLM-Image transformer: "
+        f"Enabling cache-dit on Flux transformer with BlockAdapter: "
         f"Fn={db_cache_config.Fn_compute_blocks}, "
         f"Bn={db_cache_config.Bn_compute_blocks}, "
         f"W={db_cache_config.max_warmup_steps}, "
-        f"force_refresh_step_hint={db_cache_config.force_refresh_step_hint}, "
     )

+    # Enable cache-dit using BlockAdapter for transformer
     cache_dit.enable_cache(
-        pipeline.transformer,
+        (
+            BlockAdapter(
+                transformer=pipeline.transformer,
+                blocks=[
+                    pipeline.transformer.transformer_blocks,
+                    pipeline.transformer.single_transformer_blocks,
+                ],
+                forward_pattern=[ForwardPattern.Pattern_1, ForwardPattern.Pattern_2],
+                params_modifiers=[modifier],
+            )
+        ),
         cache_config=db_cache_config,
-        calibrator_config=calibrator_config,
     )

+    def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool = True) -> None:
+        """Refresh cache context for the transformer with new num_inference_steps.

-def enable_cache_for_flux2(pipeline: Any, cache_config: Any) -> Callable[[int], None]:
-    """Enable cache-dit for Flux.2-dev pipeline.
+        Args:
+            pipeline: The Flux2 pipeline instance.
+            num_inference_steps: New number of inference steps.
+        """
+        if cache_config.scm_steps_mask_policy is None:
+            cache_dit.refresh_context(pipeline.transformer, num_inference_steps=num_inference_steps, verbose=verbose)
+        else:
+            cache_dit.refresh_context(
+                pipeline.transformer,
+                cache_config=DBCacheConfig().reset(
+                    num_inference_steps=num_inference_steps,
+                    steps_computation_mask=cache_dit.steps_mask(
+                        mask_policy=cache_config.scm_steps_mask_policy,
+                        total_steps=num_inference_steps,
+                    ),
+                    steps_computation_policy=cache_config.scm_steps_policy,
+                ),
+                verbose=verbose,
+            )
+
+    return refresh_cache_context
+
+
+def enable_cache_for_glm_image(pipeline: Any, cache_config: Any) -> Callable[[int], None]:
+    """Enable cache-dit for GlmImage pipeline.

     Args:
-        pipeline: The Flux2 pipeline instance.
+        pipeline: The GlmImage pipeline instance.
         cache_config: DiffusionCacheConfig instance with cache configuration.
     Returns:
         A refresh function that can be called with a new ``num_inference_steps``
@@ -1226,23 +1270,25 @@ def enable_cache_for_flux2(pipeline: Any, cache_config: Any) -> Callable[[int],
     )

     logger.info(
-        f"Enabling cache-dit on Flux transformer with BlockAdapter: "
+        f"Enabling cache-dit on GlmImage transformer with BlockAdapter: "
         f"Fn={db_cache_config.Fn_compute_blocks}, "
         f"Bn={db_cache_config.Bn_compute_blocks}, "
         f"W={db_cache_config.max_warmup_steps}, "
     )

     # Enable cache-dit using BlockAdapter for transformer
+    # Note: We don't use patch_functor here because it's designed for diffusers' GlmImage,
+    # and our vllm-omni implementation has a different forward signature.
+    # We use ForwardPattern.Pattern_0 because our block returns (hidden_states, encoder_hidden_states)
     cache_dit.enable_cache(
         (
             BlockAdapter(
                 transformer=pipeline.transformer,
-                blocks=[
-                    pipeline.transformer.transformer_blocks,
-                    pipeline.transformer.single_transformer_blocks,
-                ],
-                forward_pattern=[ForwardPattern.Pattern_1, ForwardPattern.Pattern_2],
+                blocks=pipeline.transformer.transformer_blocks,
+                forward_pattern=ForwardPattern.Pattern_0,
                 params_modifiers=[modifier],
+                patch_functor=None,
+                has_separate_cfg=True,
             )
         ),
         cache_config=db_cache_config,
@@ -1252,7 +1298,7 @@ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool
         """Refresh cache context for the transformer with new num_inference_steps.

         Args:
-            pipeline: The Flux2 pipeline instance.
+            pipeline: The GlmImage pipeline instance.
             num_inference_steps: New number of inference steps.
         """
         if cache_config.scm_steps_mask_policy is None: