From 17d0115909efabec4832c059118ea824d098c5a3 Mon Sep 17 00:00:00 2001 From: samithuang <285365963@qq.com> Date: Mon, 26 Jan 2026 06:32:32 +0000 Subject: [PATCH 1/4] add cachedit to wan2.2 Signed-off-by: samithuang <285365963@qq.com> --- .../text_to_video/text_to_video.py | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/examples/offline_inference/text_to_video/text_to_video.py b/examples/offline_inference/text_to_video/text_to_video.py index 066714c549b..8e8f55b3a8e 100644 --- a/examples/offline_inference/text_to_video/text_to_video.py +++ b/examples/offline_inference/text_to_video/text_to_video.py @@ -33,6 +33,17 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--flow_shift", type=float, default=5.0, help="Scheduler flow_shift (5.0 for 720p, 12.0 for 480p)." ) + parser.add_argument( + "--cache_backend", + type=str, + default=None, + choices=["cache_dit"], + help=( + "Cache backend to use for acceleration. " + "Options: 'cache_dit' (DBCache + SCM + TaylorSeer). " + "Default: None (no cache acceleration)." + ), + ) parser.add_argument("--output", type=str, default="wan22_output.mp4", help="Path to save the video (mp4).") parser.add_argument("--fps", type=int, default=24, help="Frames per second for the output video.") parser.add_argument( @@ -52,6 +63,24 @@ def main(): vae_use_slicing = is_npu() vae_use_tiling = is_npu() + # Wan2.2 cache-dit tuning (from cache-dit examples and cache_alignment). + cache_config = None + if args.cache_backend == "cache_dit": + cache_config = { + # DBCache parameters [cache-dit only] + "Fn_compute_blocks": 1, # Optimized for single-transformer models + "Bn_compute_blocks": 0, # Number of backward compute blocks + "max_warmup_steps": 4, # Maximum warmup steps (works for few-step models) + "residual_diff_threshold": 0.24, # Higher threshold for more aggressive caching + "max_continuous_cached_steps": 3, # Limit to prevent precision degradation + # TaylorSeer parameters [cache-dit only] + "enable_taylorseer": False, # Disabled by default (not suitable for few-step models) + "taylorseer_order": 1, # TaylorSeer polynomial order + # SCM (Step Computation Masking) parameters [cache-dit only] + "scm_steps_mask_policy": None, # SCM mask policy: None (disabled), "slow", "medium", "fast", "ultra" + "scm_steps_policy": "dynamic", # SCM steps policy: "dynamic" or "static" + } + # Check if profiling is requested via environment variable profiler_enabled = bool(os.getenv("VLLM_TORCH_PROFILER_DIR")) @@ -61,6 +90,8 @@ def main(): vae_use_tiling=vae_use_tiling, boundary_ratio=args.boundary_ratio, flow_shift=args.flow_shift, + cache_backend=args.cache_backend, + cache_config=cache_config, enable_cpu_offload=args.enable_cpu_offload, ) From b19b0805c1b8fee6143caff40f5123a87125365c Mon Sep 17 00:00:00 2001 From: samithuang <285365963@qq.com> Date: Wed, 28 Jan 2026 09:15:01 +0000 Subject: [PATCH 2/4] fix wan cachedit Signed-off-by: samithuang <285365963@qq.com> --- .../text_to_image/text_to_image.py | 6 + .../text_to_video/text_to_video.py | 13 ++ .../diffusion/cache/cache_dit_backend.py | 111 ++++++++++-------- vllm_omni/diffusion/data.py | 1 + vllm_omni/diffusion/diffusion_engine.py | 1 + vllm_omni/diffusion/request.py | 3 + .../worker/gpu_diffusion_model_runner.py | 5 + vllm_omni/diffusion/worker/npu/npu_worker.py | 6 +- vllm_omni/entrypoints/async_omni.py | 1 + vllm_omni/entrypoints/cli/serve.py | 5 + 10 files changed, 100 insertions(+), 52 deletions(-) diff --git a/examples/offline_inference/text_to_image/text_to_image.py b/examples/offline_inference/text_to_image/text_to_image.py index 428c2437e6f..bae726dba24 100644 --- a/examples/offline_inference/text_to_image/text_to_image.py +++ b/examples/offline_inference/text_to_image/text_to_image.py @@ -72,6 +72,11 @@ def parse_args() -> argparse.Namespace: "Default: None (no cache acceleration)." ), ) + parser.add_argument( + "--enable-cache-dit-summary", + action="store_true", + help="Enable cache-dit summary logging after diffusion forward passes.", + ) parser.add_argument( "--ulysses_degree", type=int, @@ -165,6 +170,7 @@ def main(): vae_use_tiling=vae_use_tiling, cache_backend=args.cache_backend, cache_config=cache_config, + enable_cache_dit_summary=args.enable_cache_dit_summary, parallel_config=parallel_config, enforce_eager=args.enforce_eager, enable_cpu_offload=args.enable_cpu_offload, diff --git a/examples/offline_inference/text_to_video/text_to_video.py b/examples/offline_inference/text_to_video/text_to_video.py index 8e8f55b3a8e..6796dc45b28 100644 --- a/examples/offline_inference/text_to_video/text_to_video.py +++ b/examples/offline_inference/text_to_video/text_to_video.py @@ -44,8 +44,18 @@ def parse_args() -> argparse.Namespace: "Default: None (no cache acceleration)." ), ) + parser.add_argument( + "--enable-cache-dit-summary", + action="store_true", + help="Enable cache-dit summary logging after diffusion forward passes.", + ) parser.add_argument("--output", type=str, default="wan22_output.mp4", help="Path to save the video (mp4).") parser.add_argument("--fps", type=int, default=24, help="Frames per second for the output video.") + parser.add_argument( + "--enforce_eager", + action="store_true", + help="Disable torch.compile and force eager execution.", + ) parser.add_argument( "--enable-cpu-offload", action="store_true", @@ -71,6 +81,7 @@ def main(): "Fn_compute_blocks": 1, # Optimized for single-transformer models "Bn_compute_blocks": 0, # Number of backward compute blocks "max_warmup_steps": 4, # Maximum warmup steps (works for few-step models) + "max_cached_steps": 20, "residual_diff_threshold": 0.24, # Higher threshold for more aggressive caching "max_continuous_cached_steps": 3, # Limit to prevent precision degradation # TaylorSeer parameters [cache-dit only] @@ -92,6 +103,8 @@ def main(): flow_shift=args.flow_shift, cache_backend=args.cache_backend, cache_config=cache_config, + enable_cache_dit_summary=args.enable_cache_dit_summary, + enforce_eager=args.enforce_eager, enable_cpu_offload=args.enable_cpu_offload, ) diff --git a/vllm_omni/diffusion/cache/cache_dit_backend.py b/vllm_omni/diffusion/cache/cache_dit_backend.py index 0c43659c9a6..c922121e9ae 100644 --- a/vllm_omni/diffusion/cache/cache_dit_backend.py +++ b/vllm_omni/diffusion/cache/cache_dit_backend.py @@ -28,6 +28,13 @@ logger = init_logger(__name__) +# Small helper to centralize cache-dit summaries. +def cache_summary(pipeline: Any, details: bool = True) -> None: + cache_dit.summary(pipeline.transformer, details=details) + if hasattr(pipeline, "transformer_2"): + cache_dit.summary(pipeline.transformer_2, details=details) + + # Registry of custom cache-dit enablers for specific models # Maps pipeline names to their cache-dit enablement functions # Models in this registry require custom handling (e.g., dual-transformer architectures) @@ -71,53 +78,46 @@ def enable_cache_for_wan22(pipeline: Any, cache_config: Any) -> Callable[[int], A refresh function that can be called to update cache context with new num_inference_steps. """ - # Build DBCacheConfig for primary transformer - primary_cache_config = _build_db_cache_config(cache_config) - - # FIXME: secondary cache shares the same config with primary cache for now, - # but we should support different config for secondary transformer in the future - # Build DBCacheConfig for secondary transformer (can use same or different config) - secondary_cache_config = _build_db_cache_config(cache_config) - - # Build calibrator configs if TaylorSeer is enabled - primary_calibrator = None - secondary_calibrator = None - if cache_config.enable_taylorseer: - taylorseer_order = cache_config.taylorseer_order - primary_calibrator = TaylorSeerCalibratorConfig(taylorseer_order=taylorseer_order) - secondary_calibrator = TaylorSeerCalibratorConfig(taylorseer_order=taylorseer_order) - logger.info(f"TaylorSeer enabled with order={taylorseer_order}") - - # Build ParamsModifier for each transformer - primary_modifier = ParamsModifier( - cache_config=primary_cache_config, - calibrator_config=primary_calibrator, - ) - secondary_modifier = ParamsModifier( - cache_config=secondary_cache_config, - calibrator_config=secondary_calibrator, - ) - - logger.info( - "Enabling cache-dit on Wan2.2 dual transformers with BlockAdapter: " - f"Fn={primary_cache_config.Fn_compute_blocks}, " - f"Bn={primary_cache_config.Bn_compute_blocks}, " - f"W={primary_cache_config.max_warmup_steps}, " - ) - - transformer = pipeline.transformer - transformer_2 = pipeline.transformer_2 - transformer_blocks = transformer.blocks - transformer_2_blocks = transformer_2.blocks - # Enable cache-dit using BlockAdapter for both transformers simultaneously cache_dit.enable_cache( BlockAdapter( - transformer=[transformer, transformer_2], - blocks=[transformer_blocks, transformer_2_blocks], - forward_pattern=[ForwardPattern.Pattern_2, ForwardPattern.Pattern_2], - params_modifiers=[primary_modifier, secondary_modifier], + transformer=[ + pipeline.transformer, + pipeline.transformer_2, + ], + blocks=[ + pipeline.transformer.blocks, + pipeline.transformer_2.blocks, + ], + forward_pattern=[ + ForwardPattern.Pattern_2, + ForwardPattern.Pattern_2, + ], + params_modifiers=[ + # high-noise transformer only have 30% steps + ParamsModifier( + cache_config=DBCacheConfig().reset( + max_warmup_steps=cache_config.max_warmup_steps, + max_cached_steps=cache_config.max_cached_steps, + ), + ), + ParamsModifier( + cache_config=DBCacheConfig().reset( + max_warmup_steps=2, + max_cached_steps=20, + ), + ), + ], has_separate_cfg=True, ), + cache_config=DBCacheConfig( + Fn_compute_blocks=cache_config.Fn_compute_blocks, + Bn_compute_blocks=cache_config.Bn_compute_blocks, + max_warmup_steps=cache_config.max_warmup_steps, + max_cached_steps=cache_config.max_cached_steps, + max_continuous_cached_steps=cache_config.max_continuous_cached_steps, + residual_diff_threshold=cache_config.residual_diff_threshold, + num_inference_steps=None, + ), ) # from https://github.com/vipshop/cache-dit/pull/542 @@ -133,8 +133,8 @@ def _split_inference_steps(num_inference_steps: int) -> tuple[int, int]: Returns: A tuple of (num_high_noise_steps, num_low_noise_steps). """ - if pipeline.config.boundary_ratio is not None: - boundary_timestep = pipeline.config.boundary_ratio * pipeline.scheduler.config.num_train_timesteps + if pipeline.boundary_ratio is not None: + boundary_timestep = pipeline.boundary_ratio * pipeline.scheduler.config.num_train_timesteps else: boundary_timestep = None @@ -158,10 +158,21 @@ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool pipeline: The Wan2.2 pipeline instance. num_inference_steps: New number of inference steps. """ + num_high_noise_steps, num_low_noise_steps = _split_inference_steps(num_inference_steps) # Refresh context for high-noise transformer if cache_config.scm_steps_mask_policy is None: - cache_dit.refresh_context(pipeline.transformer, num_inference_steps=num_high_noise_steps, verbose=verbose) + # cache_dit.refresh_context(pipeline.transformer, num_inference_steps=num_high_noise_steps, verbose=verbose) + cache_dit.refresh_context( + pipeline.transformer, + num_inference_steps=num_high_noise_steps, + verbose=verbose, + ) + cache_dit.refresh_context( + pipeline.transformer_2, + num_inference_steps=num_low_noise_steps, + verbose=verbose, + ) else: cache_dit.refresh_context( pipeline.transformer, @@ -175,10 +186,6 @@ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool verbose=verbose, ) - # Refresh context for low-noise transformer - if cache_config.scm_steps_mask_policy is None: - cache_dit.refresh_context(pipeline.transformer_2, num_inference_steps=num_low_noise_steps, verbose=verbose) - else: cache_dit.refresh_context( pipeline.transformer_2, cache_config=DBCacheConfig().reset( @@ -790,7 +797,9 @@ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool # Register custom cache-dit enablers after function definitions CUSTOM_DIT_ENABLERS.update( { - "WanPipeline": enable_cache_for_wan22, + "Wan22Pipeline": enable_cache_for_wan22, + "Wan22I2VPipeline": enable_cache_for_wan22, + "Wan22TI2VPipeline": enable_cache_for_wan22, "FluxPipeline": enable_cache_for_flux, "LongCatImagePipeline": enable_cache_for_longcat_image, "LongCatImageEditPipeline": enable_cache_for_longcat_image, diff --git a/vllm_omni/diffusion/data.py b/vllm_omni/diffusion/data.py index f85903b9bd4..ed219560569 100644 --- a/vllm_omni/diffusion/data.py +++ b/vllm_omni/diffusion/data.py @@ -258,6 +258,7 @@ class OmniDiffusionConfig: # Cache backend configuration (NEW) cache_backend: str = "none" # "tea_cache", "deep_cache", etc. cache_config: DiffusionCacheConfig | dict[str, Any] = field(default_factory=dict) + enable_cache_dit_summary: bool = False # Distributed executor backend distributed_executor_backend: str = "mp" diff --git a/vllm_omni/diffusion/diffusion_engine.py b/vllm_omni/diffusion/diffusion_engine.py index 2e00f41b97f..b747845f7fc 100644 --- a/vllm_omni/diffusion/diffusion_engine.py +++ b/vllm_omni/diffusion/diffusion_engine.py @@ -334,6 +334,7 @@ def _dummy_run(self): pil_image=dummy_image, num_inference_steps=num_inference_steps, num_outputs_per_prompt=1, + skip_cache_refresh=True, ) logger.info("dummy run to warm up the model") requests = self.pre_process_func([req]) if self.pre_process_func is not None else [req] diff --git a/vllm_omni/diffusion/request.py b/vllm_omni/diffusion/request.py index 279b3bf1dfc..b8715c20d09 100644 --- a/vllm_omni/diffusion/request.py +++ b/vllm_omni/diffusion/request.py @@ -162,6 +162,9 @@ class OmniDiffusionRequest: # debugging debug: bool = False + # cache behavior + skip_cache_refresh: bool = False + # results output: torch.Tensor | None = None diff --git a/vllm_omni/diffusion/worker/gpu_diffusion_model_runner.py b/vllm_omni/diffusion/worker/gpu_diffusion_model_runner.py index 5fb0cc670c0..eb124cd1e43 100644 --- a/vllm_omni/diffusion/worker/gpu_diffusion_model_runner.py +++ b/vllm_omni/diffusion/worker/gpu_diffusion_model_runner.py @@ -20,6 +20,7 @@ from vllm.logger import init_logger from vllm.utils.mem_utils import DeviceMemoryProfiler, GiB_bytes +from vllm_omni.diffusion.cache.cache_dit_backend import cache_summary from vllm_omni.diffusion.cache.selector import get_cache_backend from vllm_omni.diffusion.compile import regionally_compile from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig @@ -292,4 +293,8 @@ def execute_model(self, reqs: list[OmniDiffusionRequest]) -> DiffusionOutput: with record_function("pipeline_forward"): output = self.pipeline.forward(req) + # NOTE: + if self.od_config.cache_backend == "cache_dit" and self.od_config.enable_cache_dit_summary: + cache_summary(self.pipeline, details=True) + return output diff --git a/vllm_omni/diffusion/worker/npu/npu_worker.py b/vllm_omni/diffusion/worker/npu/npu_worker.py index 446c29cae4d..c4d2c60e553 100644 --- a/vllm_omni/diffusion/worker/npu/npu_worker.py +++ b/vllm_omni/diffusion/worker/npu/npu_worker.py @@ -132,7 +132,11 @@ def execute_model(self, reqs: list[OmniDiffusionRequest], od_config: OmniDiffusi req.generator = torch.Generator(device=self.device).manual_seed(req.seed) # Refresh cache context if needed - if self.cache_backend is not None and self.cache_backend.is_enabled(): + if ( + not getattr(req, "skip_cache_refresh", False) + and self.cache_backend is not None + and self.cache_backend.is_enabled() + ): self.cache_backend.refresh(self.pipeline, req.num_inference_steps) with set_forward_context(vllm_config=self.vllm_config, omni_diffusion_config=self.od_config): output = self.pipeline.forward(req) diff --git a/vllm_omni/entrypoints/async_omni.py b/vllm_omni/entrypoints/async_omni.py index 3c275147fa0..f76b0652f57 100644 --- a/vllm_omni/entrypoints/async_omni.py +++ b/vllm_omni/entrypoints/async_omni.py @@ -162,6 +162,7 @@ def _create_default_diffusion_stage_cfg(self, kwargs: dict[str, Any]) -> dict[st "vae_use_tiling": kwargs.get("vae_use_tiling", False), "cache_backend": cache_backend, "cache_config": cache_config, + "enable_cache_dit_summary": kwargs.get("enable_cache_dit_summary", False), "enable_cpu_offload": kwargs.get("enable_cpu_offload", False), "enforce_eager": kwargs.get("enforce_eager", False), }, diff --git a/vllm_omni/entrypoints/cli/serve.py b/vllm_omni/entrypoints/cli/serve.py index c3a37e3c82e..bbd16711194 100644 --- a/vllm_omni/entrypoints/cli/serve.py +++ b/vllm_omni/entrypoints/cli/serve.py @@ -175,6 +175,11 @@ def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgu default=None, help="JSON string of cache configuration (e.g., '{\"rel_l1_thresh\": 0.2}').", ) + omni_config_group.add_argument( + "--enable-cache-dit-summary", + action="store_true", + help="Enable cache-dit summary logging after diffusion forward passes.", + ) # VAE memory optimization parameters omni_config_group.add_argument( From 3b0355e8f9c27c0a4404ce8b9fd1a519cc31f06f Mon Sep 17 00:00:00 2001 From: samithuang <285365963@qq.com> Date: Wed, 28 Jan 2026 09:43:13 +0000 Subject: [PATCH 3/4] fix example Signed-off-by: samithuang <285365963@qq.com> --- examples/offline_inference/text_to_video/text_to_video.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/examples/offline_inference/text_to_video/text_to_video.py b/examples/offline_inference/text_to_video/text_to_video.py index b1d5e8b6db3..49522278f96 100644 --- a/examples/offline_inference/text_to_video/text_to_video.py +++ b/examples/offline_inference/text_to_video/text_to_video.py @@ -76,11 +76,6 @@ def parse_args() -> argparse.Namespace: default=1, help="Number of GPUs used for ring sequence parallelism.", ) - parser.add_argument( - "--enforce_eager", - action="store_true", - help="Disable torch.compile and force eager execution.", - ) return parser.parse_args() @@ -130,7 +125,6 @@ def main(): cache_backend=args.cache_backend, cache_config=cache_config, enable_cache_dit_summary=args.enable_cache_dit_summary, - enforce_eager=args.enforce_eager, enable_cpu_offload=args.enable_cpu_offload, parallel_config=parallel_config, enforce_eager=args.enforce_eager, From adfd2fea850b0bd3868a4647935abcda6a90ea9d Mon Sep 17 00:00:00 2001 From: samithuang <285365963@qq.com> Date: Wed, 28 Jan 2026 09:55:42 +0000 Subject: [PATCH 4/4] fix runner Signed-off-by: samithuang <285365963@qq.com> --- vllm_omni/diffusion/worker/gpu_diffusion_model_runner.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vllm_omni/diffusion/worker/gpu_diffusion_model_runner.py b/vllm_omni/diffusion/worker/gpu_diffusion_model_runner.py index 7cf34608223..0401655a858 100644 --- a/vllm_omni/diffusion/worker/gpu_diffusion_model_runner.py +++ b/vllm_omni/diffusion/worker/gpu_diffusion_model_runner.py @@ -285,7 +285,11 @@ def execute_model(self, req: OmniDiffusionRequest) -> DiffusionOutput: req.sampling_params.generator = torch.Generator(device=self.device).manual_seed(req.sampling_params.seed) # Refresh cache context if needed - if self.cache_backend is not None and self.cache_backend.is_enabled(): + if ( + not getattr(req, "skip_cache_refresh", False) + and self.cache_backend is not None + and self.cache_backend.is_enabled() + ): self.cache_backend.refresh(self.pipeline, req.sampling_params.num_inference_steps) with set_forward_context(vllm_config=self.vllm_config, omni_diffusion_config=self.od_config):