Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions vllm_omni/diffusion/cache/cache_dit_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,77 @@ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool
return refresh_cache_context


def enable_cache_for_stable_audio_open(pipeline: Any, cache_config: Any) -> Callable[[int], None]:
    """Turn on cache-dit acceleration for a Stable Audio Open pipeline.

    Wraps ``pipeline.transformer`` and its ``transformer_blocks`` in a
    ``BlockAdapter`` configured with ``ForwardPattern.Pattern_3`` and enables
    caching on it.

    Args:
        pipeline: The StableAudioPipeline instance.
        cache_config: DiffusionCacheConfig instance with cache configuration.

    Returns:
        A refresh callback, invoked as
        ``refresh(pipeline, num_inference_steps, verbose=True)``, that updates
        the transformer's cache context for a new number of inference steps.
    """
    dbcache_cfg = _build_db_cache_config(cache_config)

    # TaylorSeer calibration is opt-in; when disabled no calibrator is passed on.
    calibrator = None
    if cache_config.enable_taylorseer:
        taylorseer_order = cache_config.taylorseer_order
        calibrator = TaylorSeerCalibratorConfig(taylorseer_order=taylorseer_order)
        logger.info(f"TaylorSeer enabled with order={taylorseer_order}")

    # CacheDiT officially registers StableAudio under Pattern_3:
    # https://github.com/vipshop/cache-dit/blob/69e82bd1/src/cache_dit/caching/block_adapters/__init__.py#L562
    #
    # StableAudioDiT uses cross-attention with static encoder_hidden_states
    # that do not change inside the transformer block loop, which is why
    # Pattern_3 is required.
    transformer = pipeline.transformer
    blocks = pipeline.transformer.transformer_blocks
    modifier = ParamsModifier(
        cache_config=dbcache_cfg,
        calibrator_config=calibrator,
    )
    adapter = BlockAdapter(
        transformer=transformer,
        blocks=blocks,
        forward_pattern=ForwardPattern.Pattern_3,
        params_modifiers=[modifier],
    )
    cache_dit.enable_cache(adapter, cache_config=dbcache_cfg)

    def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool = True) -> None:
        """Update the transformer's cache context for a new step count.

        Args:
            pipeline: The StableAudioPipeline instance.
            num_inference_steps: New number of inference steps.
            verbose: Whether to log refresh operations.
        """
        # SCM is bypassed for step counts that don't support predefined masks
        # (e.g., vLLM's 1-step dummy run).
        scm_supported_steps = num_inference_steps in (4, 6) or num_inference_steps >= 8

        # Plain refresh: either no SCM mask policy is configured or the step
        # count cannot use one.
        if cache_config.scm_steps_mask_policy is None or not scm_supported_steps:
            cache_dit.refresh_context(pipeline.transformer, num_inference_steps=num_inference_steps, verbose=verbose)
            return

        # SCM refresh: build a fresh config carrying the mask for this step count.
        scm_cfg = DBCacheConfig()
        computation_mask = cache_dit.steps_mask(
            mask_policy=cache_config.scm_steps_mask_policy,
            total_steps=num_inference_steps,
        )
        updated_scm_config = scm_cfg.reset(
            num_inference_steps=num_inference_steps,
            steps_computation_mask=computation_mask,
            steps_computation_policy=cache_config.scm_steps_policy,
        )
        cache_dit.refresh_context(
            pipeline.transformer,
            cache_config=updated_scm_config,
            verbose=verbose,
        )

    return refresh_cache_context


def enable_cache_for_sd3(pipeline: Any, cache_config: Any) -> Callable[[int], None]:
"""Enable cache-dit for StableDiffusion3Pipeline.

Expand Down Expand Up @@ -1212,6 +1283,7 @@ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool
"Flux2KleinPipeline": enable_cache_for_flux2_klein,
"LongCatImagePipeline": enable_cache_for_longcat_image,
"LongCatImageEditPipeline": enable_cache_for_longcat_image,
"StableAudioPipeline": enable_cache_for_stable_audio_open,
"StableDiffusion3Pipeline": enable_cache_for_sd3,
"LTX2Pipeline": enable_cache_for_ltx2,
"LTX2ImageToVideoPipeline": enable_cache_for_ltx2,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,8 @@ class StableAudioDiTModel(nn.Module):
- Output: [B, out_channels, L]
"""

_repeated_blocks = ["StableAudioDiTBlock"]

def __init__(
self,
od_config: OmniDiffusionConfig | None = None,
Expand Down
Loading