def enable_cache_for_stable_audio_open(pipeline: Any, cache_config: Any) -> Callable[[int], None]:
    """Turn on cache-dit for a Stable Audio Open pipeline.

    The pipeline's transformer is wrapped in a ``BlockAdapter`` and handed to
    ``cache_dit.enable_cache``. A closure is returned that re-initialises the
    cache context whenever the number of inference steps changes.

    Args:
        pipeline: The StableAudioPipeline instance.
        cache_config: DiffusionCacheConfig instance with cache configuration.

    Returns:
        A refresh function that can be called to update the cache context
        with a new ``num_inference_steps``.
    """
    db_cache_config = _build_db_cache_config(cache_config)

    # The calibrator stays None unless TaylorSeer is switched on in the config.
    calibrator_config = None
    if cache_config.enable_taylorseer:
        taylorseer_order = cache_config.taylorseer_order
        calibrator_config = TaylorSeerCalibratorConfig(taylorseer_order=taylorseer_order)
        logger.info(f"TaylorSeer enabled with order={taylorseer_order}")

    # StableAudio is officially registered in CacheDiT as Pattern_3:
    # https://github.com/vipshop/cache-dit/blob/69e82bd1/src/cache_dit/caching/block_adapters/__init__.py#L562
    #
    # Pattern_3 is required because StableAudioDiT uses cross-attention
    # with static encoder_hidden_states that do not change inside the
    # transformer block loop.
    block_adapter = BlockAdapter(
        transformer=pipeline.transformer,
        blocks=pipeline.transformer.transformer_blocks,
        forward_pattern=ForwardPattern.Pattern_3,
        params_modifiers=[
            ParamsModifier(
                cache_config=db_cache_config,
                calibrator_config=calibrator_config,
            )
        ],
    )
    cache_dit.enable_cache(block_adapter, cache_config=db_cache_config)

    def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool = True) -> None:
        """Re-initialise the transformer's cache context for a new step count.

        Args:
            pipeline: The StableAudioPipeline instance.
            num_inference_steps: New number of inference steps.
            verbose: Whether to log refresh operations.
        """
        # Bypass SCM for step counts that don't support predefined masks (e.g., vLLM's 1-step dummy run)
        scm_supported_steps = num_inference_steps >= 8 or num_inference_steps in (4, 6)

        # Plain refresh: no mask policy configured, or the step count is unsupported.
        if cache_config.scm_steps_mask_policy is None or not scm_supported_steps:
            cache_dit.refresh_context(
                pipeline.transformer,
                num_inference_steps=num_inference_steps,
                verbose=verbose,
            )
            return

        # SCM refresh: start from a fresh DBCacheConfig and reset it with a
        # steps mask generated for the requested step count.
        fresh_config = DBCacheConfig()
        computation_mask = cache_dit.steps_mask(
            mask_policy=cache_config.scm_steps_mask_policy,
            total_steps=num_inference_steps,
        )
        updated_scm_config = fresh_config.reset(
            num_inference_steps=num_inference_steps,
            steps_computation_mask=computation_mask,
            steps_computation_policy=cache_config.scm_steps_policy,
        )
        cache_dit.refresh_context(
            pipeline.transformer,
            cache_config=updated_scm_config,
            verbose=verbose,
        )

    return refresh_cache_context
@@ -1212,6 +1283,7 @@ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool "Flux2KleinPipeline": enable_cache_for_flux2_klein, "LongCatImagePipeline": enable_cache_for_longcat_image, "LongCatImageEditPipeline": enable_cache_for_longcat_image, + "StableAudioPipeline": enable_cache_for_stable_audio_open, "StableDiffusion3Pipeline": enable_cache_for_sd3, "LTX2Pipeline": enable_cache_for_ltx2, "LTX2ImageToVideoPipeline": enable_cache_for_ltx2, diff --git a/vllm_omni/diffusion/models/stable_audio/stable_audio_transformer.py b/vllm_omni/diffusion/models/stable_audio/stable_audio_transformer.py index 22d56ac1fd1..4a4892673f1 100644 --- a/vllm_omni/diffusion/models/stable_audio/stable_audio_transformer.py +++ b/vllm_omni/diffusion/models/stable_audio/stable_audio_transformer.py @@ -375,6 +375,8 @@ class StableAudioDiTModel(nn.Module): - Output: [B, out_channels, L] """ + _repeated_blocks = ["StableAudioDiTBlock"] + def __init__( self, od_config: OmniDiffusionConfig | None = None,