Review note: line breaks and Python indentation were reconstructed from a
whitespace-collapsed copy of this patch. enable_cache_for_glm_image now returns
a refresh callable, as its annotation promises; hunk line counts are adjusted
accordingly. The "index" blob hashes below are those of the original patch.

diff --git a/examples/offline_inference/glm_image/end2end.py b/examples/offline_inference/glm_image/end2end.py
index 3cb74f82cca..13bcd23f55a 100644
--- a/examples/offline_inference/glm_image/end2end.py
+++ b/examples/offline_inference/glm_image/end2end.py
@@ -238,8 +238,24 @@ def main(args: argparse.Namespace) -> None:
     if args.negative_prompt:
         prompt_dict["negative_prompt"] = args.negative_prompt
 
+    # Build cache-dit config if requested
+    cache_config = None
+    if args.cache_backend == "cache_dit":
+        cache_config = {
+            "Fn_compute_blocks": 1,
+            "Bn_compute_blocks": 0,
+            "max_warmup_steps": 4,
+            "residual_diff_threshold": 0.24,
+            "max_continuous_cached_steps": 3,
+            "enable_taylorseer": False,
+            "taylorseer_order": 1,
+            "scm_steps_mask_policy": None,
+            "scm_steps_policy": "dynamic",
+        }
+
     # Initialize Omni with multistage config
     print("\nInitializing Omni with multistage pipeline...")
+    print(f" Cache backend: {args.cache_backend or 'None (no acceleration)'}")
     start_time = time.time()
 
     omni = Omni(
@@ -247,6 +263,9 @@ def main(args: argparse.Namespace) -> None:
         stage_configs_path=config_path,
         log_stats=args.enable_stats,
         stage_init_timeout=args.stage_init_timeout,
+        cache_backend=args.cache_backend,
+        cache_config=cache_config,
+        enable_cache_dit_summary=getattr(args, "enable_cache_dit_summary", False),
         enable_diffusion_pipeline_profiler=args.enable_diffusion_pipeline_profiler,
     )
 
@@ -444,6 +463,20 @@ def parse_args() -> argparse.Namespace:
         help="Number of images to generate (default: 1)",
     )
 
+    # Cache acceleration
+    parser.add_argument(
+        "--cache-backend",
+        type=str,
+        default=None,
+        choices=["cache_dit"],
+        help="Cache backend for DiT acceleration. Default: None (no cache).",
+    )
+    parser.add_argument(
+        "--enable-cache-dit-summary",
+        action="store_true",
+        help="Enable cache-dit summary logging after diffusion forward passes.",
+    )
+
     # Runtime options
     parser.add_argument(
         "--enable-stats",
diff --git a/vllm_omni/diffusion/cache/cache_dit_backend.py b/vllm_omni/diffusion/cache/cache_dit_backend.py
index 831d55f9850..e5337be1276 100644
--- a/vllm_omni/diffusion/cache/cache_dit_backend.py
+++ b/vllm_omni/diffusion/cache/cache_dit_backend.py
@@ -61,6 +61,8 @@ def _build_db_cache_config(cache_config: Any) -> DBCacheConfig:
         max_cached_steps=cache_config.max_cached_steps,
         max_continuous_cached_steps=cache_config.max_continuous_cached_steps,
         residual_diff_threshold=cache_config.residual_diff_threshold,
+        force_refresh_step_hint=cache_config.force_refresh_step_hint,
+        force_refresh_step_policy=cache_config.force_refresh_step_policy,
     )
 
 
@@ -1091,6 +1093,54 @@ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool
     return refresh_cache_context
 
 
+def enable_cache_for_glm_image(pipeline: Any, cache_config: Any) -> Callable[[int], None]:
+    """Enable cache-dit for GLM-Image pipeline.
+
+    GLM-Image processes prompt and image by calling the transformer before the
+    denoising loop. When an input image is provided (editing mode), the cache must
+    be force-refreshed after the preprocessing step so stale hidden states are
+    discarded. Set force_refresh_step_hint = 1 for editing, None for text-to-image.
+
+    Args:
+        pipeline: Pipeline whose ``transformer`` is passed to ``cache_dit.enable_cache``.
+        cache_config: Cache settings; converted via ``_build_db_cache_config`` and read
+            for ``enable_taylorseer`` / ``taylorseer_order``.
+
+    Returns:
+        A ``refresh_cache_context(pipeline, num_inference_steps, verbose)`` callable that
+        refreshes the transformer's cache context for a new number of inference steps.
+    """
+    db_cache_config = _build_db_cache_config(cache_config)
+
+    calibrator_config = None
+    if cache_config.enable_taylorseer:
+        calibrator_config = TaylorSeerCalibratorConfig(taylorseer_order=cache_config.taylorseer_order)
+        logger.info(f"TaylorSeer enabled with order={cache_config.taylorseer_order}")
+
+    logger.info(
+        f"Enabling cache-dit on GLM-Image transformer: "
+        f"Fn={db_cache_config.Fn_compute_blocks}, "
+        f"Bn={db_cache_config.Bn_compute_blocks}, "
+        f"W={db_cache_config.max_warmup_steps}, "
+        f"force_refresh_step_hint={db_cache_config.force_refresh_step_hint}"
+    )
+
+    cache_dit.enable_cache(
+        pipeline.transformer,
+        cache_config=db_cache_config,
+        calibrator_config=calibrator_config,
+    )
+
+    def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool = True) -> None:
+        """Refresh the transformer's cache context for a new number of inference steps."""
+        # NOTE(review): this only updates the step count; if an SCM steps mask policy is
+        # configured, the mask presumably needs regenerating too - confirm against the
+        # sibling enablers in this module.
+        cache_dit.refresh_context(pipeline.transformer, num_inference_steps=num_inference_steps, verbose=verbose)
+
+    return refresh_cache_context
+
+
 def enable_cache_for_flux2(pipeline: Any, cache_config: Any) -> Callable[[int], None]:
     """Enable cache-dit for Flux.2-dev pipeline.
 
@@ -1180,6 +1230,7 @@ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool
         "LTX2Pipeline": enable_cache_for_ltx2,
         "LTX2ImageToVideoPipeline": enable_cache_for_ltx2,
         "BagelPipeline": enable_cache_for_bagel,
+        "GlmImagePipeline": enable_cache_for_glm_image,
         "Flux2Pipeline": enable_cache_for_flux2,
     }
 )
diff --git a/vllm_omni/diffusion/data.py b/vllm_omni/diffusion/data.py
index f15aae6a026..052149046da 100644
--- a/vllm_omni/diffusion/data.py
+++ b/vllm_omni/diffusion/data.py
@@ -264,6 +264,12 @@ class DiffusionCacheConfig:
     # Used by cache-dit for scm mask generation. If this value changes during inference,
     # we will re-generate the scm mask and refresh the cache context.
     num_inference_steps: int | None = None
+    # Force refresh the cache at a specific step index hint, useful for models like
+    # GLM-Image (image preprocessing step in editing mode).
+    force_refresh_step_hint: int | None = None
+    # Policy for force refresh: "once" refreshes only at the hint step,
+    # "repeat" refreshes every force_refresh_step_hint steps.
+    force_refresh_step_policy: str = "once"
     # Additional parameters that may be passed but not explicitly defined
     _extra_params: dict[str, Any] = field(default_factory=dict, repr=False)
 