diff --git a/examples/offline_inference/text_to_image/text_to_image.py b/examples/offline_inference/text_to_image/text_to_image.py index 7fc6ec832f9..efc5010e04f 100644 --- a/examples/offline_inference/text_to_image/text_to_image.py +++ b/examples/offline_inference/text_to_image/text_to_image.py @@ -73,6 +73,11 @@ def parse_args() -> argparse.Namespace: "Default: None (no cache acceleration)." ), ) + parser.add_argument( + "--enable-cache-dit-summary", + action="store_true", + help="Enable cache-dit summary logging after diffusion forward passes.", + ) parser.add_argument( "--ulysses_degree", type=int, @@ -166,6 +171,7 @@ def main(): vae_use_tiling=vae_use_tiling, cache_backend=args.cache_backend, cache_config=cache_config, + enable_cache_dit_summary=args.enable_cache_dit_summary, parallel_config=parallel_config, enforce_eager=args.enforce_eager, enable_cpu_offload=args.enable_cpu_offload, diff --git a/examples/offline_inference/text_to_video/text_to_video.py b/examples/offline_inference/text_to_video/text_to_video.py index 054db95b3e9..49522278f96 100644 --- a/examples/offline_inference/text_to_video/text_to_video.py +++ b/examples/offline_inference/text_to_video/text_to_video.py @@ -36,8 +36,29 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--flow_shift", type=float, default=5.0, help="Scheduler flow_shift (5.0 for 720p, 12.0 for 480p)." ) + parser.add_argument( + "--cache_backend", + type=str, + default=None, + choices=["cache_dit"], + help=( + "Cache backend to use for acceleration. " + "Options: 'cache_dit' (DBCache + SCM + TaylorSeer). " + "Default: None (no cache acceleration)." + ), + ) + parser.add_argument( + "--enable-cache-dit-summary", + action="store_true", + help="Enable cache-dit summary logging after diffusion forward passes.", + ) parser.add_argument("--output", type=str, default="wan22_output.mp4", help="Path to save the video (mp4).") parser.add_argument("--fps", type=int, default=24, help="Frames per second for the output video.") + parser.add_argument( + "--enforce_eager", + action="store_true", + help="Disable torch.compile and force eager execution.", + ) parser.add_argument( "--enable-cpu-offload", action="store_true", @@ -55,11 +76,6 @@ def parse_args() -> argparse.Namespace: default=1, help="Number of GPUs used for ring sequence parallelism.", ) - parser.add_argument( - "--enforce_eager", - action="store_true", - help="Disable torch.compile and force eager execution.", - ) return parser.parse_args() @@ -72,6 +88,24 @@ def main(): vae_use_slicing = is_npu() vae_use_tiling = is_npu() + # Wan2.2 cache-dit tuning (from cache-dit examples and cache_alignment). + cache_config = None + if args.cache_backend == "cache_dit": + cache_config = { + # DBCache parameters [cache-dit only] + "Fn_compute_blocks": 1, # Optimized for single-transformer models + "Bn_compute_blocks": 0, # Number of backward compute blocks + "max_warmup_steps": 4, # Maximum warmup steps (works for few-step models) + "max_cached_steps": 20, + "residual_diff_threshold": 0.24, # Higher threshold for more aggressive caching + "max_continuous_cached_steps": 3, # Limit to prevent precision degradation + # TaylorSeer parameters [cache-dit only] + "enable_taylorseer": False, # Disabled by default (not suitable for few-step models) + "taylorseer_order": 1, # TaylorSeer polynomial order + # SCM (Step Computation Masking) parameters [cache-dit only] + "scm_steps_mask_policy": None, # SCM mask policy: None (disabled), "slow", "medium", "fast", "ultra" + "scm_steps_policy": "dynamic", # SCM steps policy: "dynamic" or "static" + } # Configure parallel settings (only SP is supported for Wan) # Note: cfg_parallel and tensor_parallel are not implemented for Wan models parallel_config = DiffusionParallelConfig( @@ -88,6 +122,9 @@ def main(): vae_use_tiling=vae_use_tiling, boundary_ratio=args.boundary_ratio, flow_shift=args.flow_shift, + cache_backend=args.cache_backend, + cache_config=cache_config, + enable_cache_dit_summary=args.enable_cache_dit_summary, enable_cpu_offload=args.enable_cpu_offload, parallel_config=parallel_config, enforce_eager=args.enforce_eager, diff --git a/vllm_omni/diffusion/cache/cache_dit_backend.py b/vllm_omni/diffusion/cache/cache_dit_backend.py index ff2c9800256..c922121e9ae 100644 --- a/vllm_omni/diffusion/cache/cache_dit_backend.py +++ b/vllm_omni/diffusion/cache/cache_dit_backend.py @@ -28,6 +28,13 @@ logger = init_logger(__name__) +# Small helper to centralize cache-dit summaries. +def cache_summary(pipeline: Any, details: bool = True) -> None: + cache_dit.summary(pipeline.transformer, details=details) + if hasattr(pipeline, "transformer_2"): + cache_dit.summary(pipeline.transformer_2, details=details) + + # Registry of custom cache-dit enablers for specific models # Maps pipeline names to their cache-dit enablement functions # Models in this registry require custom handling (e.g., dual-transformer architectures) @@ -71,53 +78,46 @@ def enable_cache_for_wan22(pipeline: Any, cache_config: Any) -> Callable[[int], A refresh function that can be called to update cache context with new num_inference_steps. """ - # Build DBCacheConfig for primary transformer - primary_cache_config = _build_db_cache_config(cache_config) - - # FIXME: secondary cache shares the same config with primary cache for now, - # but we should support different config for secondary transformer in the future - # Build DBCacheConfig for secondary transformer (can use same or different config) - secondary_cache_config = _build_db_cache_config(cache_config) - - # Build calibrator configs if TaylorSeer is enabled - primary_calibrator = None - secondary_calibrator = None - if cache_config.enable_taylorseer: - taylorseer_order = cache_config.taylorseer_order - primary_calibrator = TaylorSeerCalibratorConfig(taylorseer_order=taylorseer_order) - secondary_calibrator = TaylorSeerCalibratorConfig(taylorseer_order=taylorseer_order) - logger.info(f"TaylorSeer enabled with order={taylorseer_order}") - - # Build ParamsModifier for each transformer - primary_modifier = ParamsModifier( - cache_config=primary_cache_config, - calibrator_config=primary_calibrator, - ) - secondary_modifier = ParamsModifier( - cache_config=secondary_cache_config, - calibrator_config=secondary_calibrator, - ) - - logger.info( - "Enabling cache-dit on Wan2.2 dual transformers with BlockAdapter: " - f"Fn={primary_cache_config.Fn_compute_blocks}, " - f"Bn={primary_cache_config.Bn_compute_blocks}, " - f"W={primary_cache_config.max_warmup_steps}, " - ) - - transformer = pipeline.transformer - transformer_2 = pipeline.transformer_2 - transformer_blocks = transformer.blocks - transformer_2_blocks = transformer_2.blocks - # Enable cache-dit using BlockAdapter for both transformers simultaneously cache_dit.enable_cache( BlockAdapter( - transformer=[transformer, transformer_2], - blocks=[transformer_blocks, transformer_2_blocks], - forward_pattern=[ForwardPattern.Pattern_2, ForwardPattern.Pattern_2], - params_modifiers=[primary_modifier, secondary_modifier], + transformer=[ + pipeline.transformer, + pipeline.transformer_2, + ], + blocks=[ + pipeline.transformer.blocks, + pipeline.transformer_2.blocks, + ], + forward_pattern=[ + ForwardPattern.Pattern_2, + ForwardPattern.Pattern_2, + ], + params_modifiers=[ + # high-noise transformer only have 30% steps + ParamsModifier( + cache_config=DBCacheConfig().reset( + max_warmup_steps=cache_config.max_warmup_steps, + max_cached_steps=cache_config.max_cached_steps, + ), + ), + ParamsModifier( + cache_config=DBCacheConfig().reset( + max_warmup_steps=2, + max_cached_steps=20, + ), + ), + ], has_separate_cfg=True, ), + cache_config=DBCacheConfig( + Fn_compute_blocks=cache_config.Fn_compute_blocks, + Bn_compute_blocks=cache_config.Bn_compute_blocks, + max_warmup_steps=cache_config.max_warmup_steps, + max_cached_steps=cache_config.max_cached_steps, + max_continuous_cached_steps=cache_config.max_continuous_cached_steps, + residual_diff_threshold=cache_config.residual_diff_threshold, + num_inference_steps=None, + ), ) # from https://github.com/vipshop/cache-dit/pull/542 @@ -133,8 +133,8 @@ def _split_inference_steps(num_inference_steps: int) -> tuple[int, int]: Returns: A tuple of (num_high_noise_steps, num_low_noise_steps). """ - if pipeline.config.boundary_ratio is not None: - boundary_timestep = pipeline.config.boundary_ratio * pipeline.scheduler.config.num_train_timesteps + if pipeline.boundary_ratio is not None: + boundary_timestep = pipeline.boundary_ratio * pipeline.scheduler.config.num_train_timesteps else: boundary_timestep = None @@ -158,10 +158,21 @@ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool pipeline: The Wan2.2 pipeline instance. num_inference_steps: New number of inference steps. """ + num_high_noise_steps, num_low_noise_steps = _split_inference_steps(num_inference_steps) # Refresh context for high-noise transformer if cache_config.scm_steps_mask_policy is None: - cache_dit.refresh_context(pipeline.transformer, num_inference_steps=num_high_noise_steps, verbose=verbose) + # cache_dit.refresh_context(pipeline.transformer, num_inference_steps=num_high_noise_steps, verbose=verbose) + cache_dit.refresh_context( + pipeline.transformer, + num_inference_steps=num_high_noise_steps, + verbose=verbose, + ) + cache_dit.refresh_context( + pipeline.transformer_2, + num_inference_steps=num_low_noise_steps, + verbose=verbose, + ) else: cache_dit.refresh_context( pipeline.transformer, @@ -175,10 +186,6 @@ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool verbose=verbose, ) - # Refresh context for low-noise transformer - if cache_config.scm_steps_mask_policy is None: - cache_dit.refresh_context(pipeline.transformer_2, num_inference_steps=num_low_noise_steps, verbose=verbose) - else: cache_dit.refresh_context( pipeline.transformer_2, cache_config=DBCacheConfig().reset( @@ -790,8 +797,10 @@ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool # Register custom cache-dit enablers after function definitions CUSTOM_DIT_ENABLERS.update( { - "WanPipeline": enable_cache_for_wan22, - # "FluxPipeline": enable_cache_for_flux, + "Wan22Pipeline": enable_cache_for_wan22, + "Wan22I2VPipeline": enable_cache_for_wan22, + "Wan22TI2VPipeline": enable_cache_for_wan22, + "FluxPipeline": enable_cache_for_flux, "LongCatImagePipeline": enable_cache_for_longcat_image, "LongCatImageEditPipeline": enable_cache_for_longcat_image, "StableDiffusion3Pipeline": enable_cache_for_sd3, diff --git a/vllm_omni/diffusion/data.py b/vllm_omni/diffusion/data.py index 1a98462aaea..286ab9153dc 100644 --- a/vllm_omni/diffusion/data.py +++ b/vllm_omni/diffusion/data.py @@ -258,6 +258,7 @@ class OmniDiffusionConfig: # Cache backend configuration (NEW) cache_backend: str = "none" # "tea_cache", "deep_cache", etc. cache_config: DiffusionCacheConfig | dict[str, Any] = field(default_factory=dict) + enable_cache_dit_summary: bool = False # Distributed executor backend distributed_executor_backend: str = "mp" diff --git a/vllm_omni/diffusion/worker/gpu_diffusion_model_runner.py b/vllm_omni/diffusion/worker/gpu_diffusion_model_runner.py index bf56df15590..0401655a858 100644 --- a/vllm_omni/diffusion/worker/gpu_diffusion_model_runner.py +++ b/vllm_omni/diffusion/worker/gpu_diffusion_model_runner.py @@ -20,6 +20,7 @@ from vllm.logger import init_logger from vllm.utils.mem_utils import DeviceMemoryProfiler, GiB_bytes +from vllm_omni.diffusion.cache.cache_dit_backend import cache_summary from vllm_omni.diffusion.cache.selector import get_cache_backend from vllm_omni.diffusion.compile import regionally_compile from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig @@ -284,11 +285,19 @@ def execute_model(self, req: OmniDiffusionRequest) -> DiffusionOutput: req.sampling_params.generator = torch.Generator(device=self.device).manual_seed(req.sampling_params.seed) # Refresh cache context if needed - if self.cache_backend is not None and self.cache_backend.is_enabled(): + if ( + not getattr(req, "skip_cache_refresh", False) + and self.cache_backend is not None + and self.cache_backend.is_enabled() + ): self.cache_backend.refresh(self.pipeline, req.sampling_params.num_inference_steps) with set_forward_context(vllm_config=self.vllm_config, omni_diffusion_config=self.od_config): with record_function("pipeline_forward"): output = self.pipeline.forward(req) + # NOTE: + if self.od_config.cache_backend == "cache_dit" and self.od_config.enable_cache_dit_summary: + cache_summary(self.pipeline, details=True) + return output diff --git a/vllm_omni/diffusion/worker/npu/npu_worker.py b/vllm_omni/diffusion/worker/npu/npu_worker.py index 446c29cae4d..c4d2c60e553 100644 --- a/vllm_omni/diffusion/worker/npu/npu_worker.py +++ b/vllm_omni/diffusion/worker/npu/npu_worker.py @@ -132,7 +132,11 @@ def execute_model(self, reqs: list[OmniDiffusionRequest], od_config: OmniDiffusi req.generator = torch.Generator(device=self.device).manual_seed(req.seed) # Refresh cache context if needed - if self.cache_backend is not None and self.cache_backend.is_enabled(): + if ( + not getattr(req, "skip_cache_refresh", False) + and self.cache_backend is not None + and self.cache_backend.is_enabled() + ): self.cache_backend.refresh(self.pipeline, req.num_inference_steps) with set_forward_context(vllm_config=self.vllm_config, omni_diffusion_config=self.od_config): output = self.pipeline.forward(req) diff --git a/vllm_omni/entrypoints/async_omni.py b/vllm_omni/entrypoints/async_omni.py index a78710fa2fd..179f58ed22b 100644 --- a/vllm_omni/entrypoints/async_omni.py +++ b/vllm_omni/entrypoints/async_omni.py @@ -162,6 +162,7 @@ def _create_default_diffusion_stage_cfg(self, kwargs: dict[str, Any]) -> dict[st "vae_use_tiling": kwargs.get("vae_use_tiling", False), "cache_backend": cache_backend, "cache_config": cache_config, + "enable_cache_dit_summary": kwargs.get("enable_cache_dit_summary", False), "enable_cpu_offload": kwargs.get("enable_cpu_offload", False), "enforce_eager": kwargs.get("enforce_eager", False), }, diff --git a/vllm_omni/entrypoints/cli/serve.py b/vllm_omni/entrypoints/cli/serve.py index c3a37e3c82e..bbd16711194 100644 --- a/vllm_omni/entrypoints/cli/serve.py +++ b/vllm_omni/entrypoints/cli/serve.py @@ -175,6 +175,11 @@ def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgu default=None, help="JSON string of cache configuration (e.g., '{\"rel_l1_thresh\": 0.2}').", ) + omni_config_group.add_argument( + "--enable-cache-dit-summary", + action="store_true", + help="Enable cache-dit summary logging after diffusion forward passes.", + ) # VAE memory optimization parameters omni_config_group.add_argument(