Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions examples/offline_inference/text_to_image/text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,11 @@ def parse_args() -> argparse.Namespace:
"Default: None (no cache acceleration)."
),
)
parser.add_argument(
"--enable-cache-dit-summary",
action="store_true",
help="Enable cache-dit summary logging after diffusion forward passes.",
)
parser.add_argument(
"--ulysses_degree",
type=int,
Expand Down Expand Up @@ -166,6 +171,7 @@ def main():
vae_use_tiling=vae_use_tiling,
cache_backend=args.cache_backend,
cache_config=cache_config,
enable_cache_dit_summary=args.enable_cache_dit_summary,
parallel_config=parallel_config,
enforce_eager=args.enforce_eager,
enable_cpu_offload=args.enable_cpu_offload,
Expand Down
47 changes: 42 additions & 5 deletions examples/offline_inference/text_to_video/text_to_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,29 @@ def parse_args() -> argparse.Namespace:
parser.add_argument(
"--flow_shift", type=float, default=5.0, help="Scheduler flow_shift (5.0 for 720p, 12.0 for 480p)."
)
parser.add_argument(
"--cache_backend",
type=str,
default=None,
choices=["cache_dit"],
help=(
"Cache backend to use for acceleration. "
"Options: 'cache_dit' (DBCache + SCM + TaylorSeer). "
"Default: None (no cache acceleration)."
),
)
parser.add_argument(
"--enable-cache-dit-summary",
action="store_true",
help="Enable cache-dit summary logging after diffusion forward passes.",
)
parser.add_argument("--output", type=str, default="wan22_output.mp4", help="Path to save the video (mp4).")
parser.add_argument("--fps", type=int, default=24, help="Frames per second for the output video.")
parser.add_argument(
"--enforce_eager",
action="store_true",
help="Disable torch.compile and force eager execution.",
)
parser.add_argument(
"--enable-cpu-offload",
action="store_true",
Expand All @@ -55,11 +76,6 @@ def parse_args() -> argparse.Namespace:
default=1,
help="Number of GPUs used for ring sequence parallelism.",
)
parser.add_argument(
"--enforce_eager",
action="store_true",
help="Disable torch.compile and force eager execution.",
)
return parser.parse_args()


Expand All @@ -72,6 +88,24 @@ def main():
vae_use_slicing = is_npu()
vae_use_tiling = is_npu()

# Wan2.2 cache-dit tuning (from cache-dit examples and cache_alignment).
cache_config = None
if args.cache_backend == "cache_dit":
cache_config = {
# DBCache parameters [cache-dit only]
"Fn_compute_blocks": 1, # Optimized for single-transformer models
"Bn_compute_blocks": 0, # Number of backward compute blocks
"max_warmup_steps": 4, # Maximum warmup steps (works for few-step models)
"max_cached_steps": 20,
"residual_diff_threshold": 0.24, # Higher threshold for more aggressive caching
"max_continuous_cached_steps": 3, # Limit to prevent precision degradation
# TaylorSeer parameters [cache-dit only]
"enable_taylorseer": False, # Disabled by default (not suitable for few-step models)
"taylorseer_order": 1, # TaylorSeer polynomial order
# SCM (Step Computation Masking) parameters [cache-dit only]
"scm_steps_mask_policy": None, # SCM mask policy: None (disabled), "slow", "medium", "fast", "ultra"
"scm_steps_policy": "dynamic", # SCM steps policy: "dynamic" or "static"
}
# Configure parallel settings (only SP is supported for Wan)
# Note: cfg_parallel and tensor_parallel are not implemented for Wan models
parallel_config = DiffusionParallelConfig(
Expand All @@ -88,6 +122,9 @@ def main():
vae_use_tiling=vae_use_tiling,
boundary_ratio=args.boundary_ratio,
flow_shift=args.flow_shift,
cache_backend=args.cache_backend,
cache_config=cache_config,
enable_cache_dit_summary=args.enable_cache_dit_summary,
enable_cpu_offload=args.enable_cpu_offload,
parallel_config=parallel_config,
enforce_eager=args.enforce_eager,
Expand Down
113 changes: 61 additions & 52 deletions vllm_omni/diffusion/cache/cache_dit_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,13 @@
logger = init_logger(__name__)


# Small helper to centralize cache-dit summaries.
def cache_summary(pipeline: Any, details: bool = True) -> None:
cache_dit.summary(pipeline.transformer, details=details)
if hasattr(pipeline, "transformer_2"):
cache_dit.summary(pipeline.transformer_2, details=details)


# Registry of custom cache-dit enablers for specific models
# Maps pipeline names to their cache-dit enablement functions
# Models in this registry require custom handling (e.g., dual-transformer architectures)
Expand Down Expand Up @@ -71,53 +78,46 @@ def enable_cache_for_wan22(pipeline: Any, cache_config: Any) -> Callable[[int],
A refresh function that can be called to update cache context with new num_inference_steps.
"""

# Build DBCacheConfig for primary transformer
primary_cache_config = _build_db_cache_config(cache_config)

# FIXME: secondary cache shares the same config with primary cache for now,
# but we should support different config for secondary transformer in the future
# Build DBCacheConfig for secondary transformer (can use same or different config)
secondary_cache_config = _build_db_cache_config(cache_config)

# Build calibrator configs if TaylorSeer is enabled
primary_calibrator = None
secondary_calibrator = None
if cache_config.enable_taylorseer:
taylorseer_order = cache_config.taylorseer_order
primary_calibrator = TaylorSeerCalibratorConfig(taylorseer_order=taylorseer_order)
secondary_calibrator = TaylorSeerCalibratorConfig(taylorseer_order=taylorseer_order)
logger.info(f"TaylorSeer enabled with order={taylorseer_order}")

# Build ParamsModifier for each transformer
primary_modifier = ParamsModifier(
cache_config=primary_cache_config,
calibrator_config=primary_calibrator,
)
secondary_modifier = ParamsModifier(
cache_config=secondary_cache_config,
calibrator_config=secondary_calibrator,
)

logger.info(
"Enabling cache-dit on Wan2.2 dual transformers with BlockAdapter: "
f"Fn={primary_cache_config.Fn_compute_blocks}, "
f"Bn={primary_cache_config.Bn_compute_blocks}, "
f"W={primary_cache_config.max_warmup_steps}, "
)

transformer = pipeline.transformer
transformer_2 = pipeline.transformer_2
transformer_blocks = transformer.blocks
transformer_2_blocks = transformer_2.blocks
# Enable cache-dit using BlockAdapter for both transformers simultaneously
cache_dit.enable_cache(
BlockAdapter(
transformer=[transformer, transformer_2],
blocks=[transformer_blocks, transformer_2_blocks],
forward_pattern=[ForwardPattern.Pattern_2, ForwardPattern.Pattern_2],
params_modifiers=[primary_modifier, secondary_modifier],
transformer=[
pipeline.transformer,
pipeline.transformer_2,
],
blocks=[
pipeline.transformer.blocks,
pipeline.transformer_2.blocks,
],
forward_pattern=[
ForwardPattern.Pattern_2,
ForwardPattern.Pattern_2,
],
params_modifiers=[
# high-noise transformer only have 30% steps
ParamsModifier(
cache_config=DBCacheConfig().reset(
max_warmup_steps=cache_config.max_warmup_steps,
max_cached_steps=cache_config.max_cached_steps,
),
),
ParamsModifier(
cache_config=DBCacheConfig().reset(
max_warmup_steps=2,
max_cached_steps=20,
),
),
],
has_separate_cfg=True,
),
cache_config=DBCacheConfig(
Fn_compute_blocks=cache_config.Fn_compute_blocks,
Bn_compute_blocks=cache_config.Bn_compute_blocks,
max_warmup_steps=cache_config.max_warmup_steps,
max_cached_steps=cache_config.max_cached_steps,
max_continuous_cached_steps=cache_config.max_continuous_cached_steps,
residual_diff_threshold=cache_config.residual_diff_threshold,
num_inference_steps=None,
),
)

# from https://github.com/vipshop/cache-dit/pull/542
Expand All @@ -133,8 +133,8 @@ def _split_inference_steps(num_inference_steps: int) -> tuple[int, int]:
Returns:
A tuple of (num_high_noise_steps, num_low_noise_steps).
"""
if pipeline.config.boundary_ratio is not None:
boundary_timestep = pipeline.config.boundary_ratio * pipeline.scheduler.config.num_train_timesteps
if pipeline.boundary_ratio is not None:
boundary_timestep = pipeline.boundary_ratio * pipeline.scheduler.config.num_train_timesteps
else:
boundary_timestep = None

Expand All @@ -158,10 +158,21 @@ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool
pipeline: The Wan2.2 pipeline instance.
num_inference_steps: New number of inference steps.
"""

num_high_noise_steps, num_low_noise_steps = _split_inference_steps(num_inference_steps)
# Refresh context for high-noise transformer
if cache_config.scm_steps_mask_policy is None:
cache_dit.refresh_context(pipeline.transformer, num_inference_steps=num_high_noise_steps, verbose=verbose)
# cache_dit.refresh_context(pipeline.transformer, num_inference_steps=num_high_noise_steps, verbose=verbose)
cache_dit.refresh_context(
pipeline.transformer,
num_inference_steps=num_high_noise_steps,
verbose=verbose,
)
cache_dit.refresh_context(
pipeline.transformer_2,
num_inference_steps=num_low_noise_steps,
verbose=verbose,
)
else:
cache_dit.refresh_context(
pipeline.transformer,
Expand All @@ -175,10 +186,6 @@ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool
verbose=verbose,
)

# Refresh context for low-noise transformer
if cache_config.scm_steps_mask_policy is None:
cache_dit.refresh_context(pipeline.transformer_2, num_inference_steps=num_low_noise_steps, verbose=verbose)
else:
cache_dit.refresh_context(
pipeline.transformer_2,
cache_config=DBCacheConfig().reset(
Expand Down Expand Up @@ -790,8 +797,10 @@ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool
# Register custom cache-dit enablers after function definitions
CUSTOM_DIT_ENABLERS.update(
{
"WanPipeline": enable_cache_for_wan22,
# "FluxPipeline": enable_cache_for_flux,
"Wan22Pipeline": enable_cache_for_wan22,
"Wan22I2VPipeline": enable_cache_for_wan22,
"Wan22TI2VPipeline": enable_cache_for_wan22,
"FluxPipeline": enable_cache_for_flux,
"LongCatImagePipeline": enable_cache_for_longcat_image,
"LongCatImageEditPipeline": enable_cache_for_longcat_image,
"StableDiffusion3Pipeline": enable_cache_for_sd3,
Expand Down
1 change: 1 addition & 0 deletions vllm_omni/diffusion/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,7 @@ class OmniDiffusionConfig:
# Cache backend configuration (NEW)
cache_backend: str = "none" # "tea_cache", "deep_cache", etc.
cache_config: DiffusionCacheConfig | dict[str, Any] = field(default_factory=dict)
enable_cache_dit_summary: bool = False

# Distributed executor backend
distributed_executor_backend: str = "mp"
Expand Down
11 changes: 10 additions & 1 deletion vllm_omni/diffusion/worker/gpu_diffusion_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from vllm.logger import init_logger
from vllm.utils.mem_utils import DeviceMemoryProfiler, GiB_bytes

from vllm_omni.diffusion.cache.cache_dit_backend import cache_summary
from vllm_omni.diffusion.cache.selector import get_cache_backend
from vllm_omni.diffusion.compile import regionally_compile
from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig
Expand Down Expand Up @@ -284,11 +285,19 @@ def execute_model(self, req: OmniDiffusionRequest) -> DiffusionOutput:
req.sampling_params.generator = torch.Generator(device=self.device).manual_seed(req.sampling_params.seed)

# Refresh cache context if needed
if self.cache_backend is not None and self.cache_backend.is_enabled():
if (
not getattr(req, "skip_cache_refresh", False)
and self.cache_backend is not None
and self.cache_backend.is_enabled()
):
self.cache_backend.refresh(self.pipeline, req.sampling_params.num_inference_steps)

with set_forward_context(vllm_config=self.vllm_config, omni_diffusion_config=self.od_config):
with record_function("pipeline_forward"):
output = self.pipeline.forward(req)

# NOTE:
if self.od_config.cache_backend == "cache_dit" and self.od_config.enable_cache_dit_summary:
cache_summary(self.pipeline, details=True)

return output
6 changes: 5 additions & 1 deletion vllm_omni/diffusion/worker/npu/npu_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,11 @@ def execute_model(self, reqs: list[OmniDiffusionRequest], od_config: OmniDiffusi
req.generator = torch.Generator(device=self.device).manual_seed(req.seed)

# Refresh cache context if needed
if self.cache_backend is not None and self.cache_backend.is_enabled():
if (
not getattr(req, "skip_cache_refresh", False)
and self.cache_backend is not None
and self.cache_backend.is_enabled()
):
self.cache_backend.refresh(self.pipeline, req.num_inference_steps)
with set_forward_context(vllm_config=self.vllm_config, omni_diffusion_config=self.od_config):
output = self.pipeline.forward(req)
Expand Down
1 change: 1 addition & 0 deletions vllm_omni/entrypoints/async_omni.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ def _create_default_diffusion_stage_cfg(self, kwargs: dict[str, Any]) -> dict[st
"vae_use_tiling": kwargs.get("vae_use_tiling", False),
"cache_backend": cache_backend,
"cache_config": cache_config,
"enable_cache_dit_summary": kwargs.get("enable_cache_dit_summary", False),
"enable_cpu_offload": kwargs.get("enable_cpu_offload", False),
"enforce_eager": kwargs.get("enforce_eager", False),
},
Expand Down
5 changes: 5 additions & 0 deletions vllm_omni/entrypoints/cli/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,11 @@ def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgu
default=None,
help="JSON string of cache configuration (e.g., '{\"rel_l1_thresh\": 0.2}').",
)
omni_config_group.add_argument(
"--enable-cache-dit-summary",
action="store_true",
help="Enable cache-dit summary logging after diffusion forward passes.",
)

# VAE memory optimization parameters
omni_config_group.add_argument(
Expand Down