vllm_omni/diffusion/cache/cache_dit_backend.py
92 changes: 69 additions & 23 deletions

@@ -1170,41 +1170,85 @@ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool
     return refresh_cache_context
 
 
-def enable_cache_for_glm_image(pipeline: Any, cache_config: Any) -> Callable[[int], None]:
-    """Enable cache-dit for GLM-Image pipeline.
+def enable_cache_for_flux2(pipeline: Any, cache_config: Any) -> Callable[[int], None]:
+    """Enable cache-dit for Flux.2-dev pipeline.
 
-    GLM-Image processes prompt and image by calling the transformer before the
-    denoising loop. When an input image is provided (editing mode), the cache must
-    be force-refreshed after the preprocessing step so stale hidden states are
-    discarded. Set force_refresh_step_hint = 1 for editing, None for text-to-image.
+    Args:
+        pipeline: The Flux2 pipeline instance.
+        cache_config: DiffusionCacheConfig instance with cache configuration.
+    Returns:
+        A refresh function that can be called with a new ``num_inference_steps``
+        to update the cache context for the pipeline.
     """
     # Build DBCacheConfig for transformer
     db_cache_config = _build_db_cache_config(cache_config)
 
-    calibrator_config = None
+    calibrator = None
     if cache_config.enable_taylorseer:
-        calibrator_config = TaylorSeerCalibratorConfig(taylorseer_order=cache_config.taylorseer_order)
-        logger.info(f"TaylorSeer enabled with order={cache_config.taylorseer_order}")
+        taylorseer_order = cache_config.taylorseer_order
+        calibrator = TaylorSeerCalibratorConfig(taylorseer_order=taylorseer_order)
+        logger.info(f"TaylorSeer enabled with order={taylorseer_order}")
+
+    # Build ParamsModifier for transformer
+    modifier = ParamsModifier(
+        cache_config=db_cache_config,
+        calibrator_config=calibrator,
+    )
 
     logger.info(
-        f"Enabling cache-dit on GLM-Image transformer: "
+        f"Enabling cache-dit on Flux transformer with BlockAdapter: "
         f"Fn={db_cache_config.Fn_compute_blocks}, "
         f"Bn={db_cache_config.Bn_compute_blocks}, "
         f"W={db_cache_config.max_warmup_steps}, "
-        f"force_refresh_step_hint={db_cache_config.force_refresh_step_hint}, "
     )
 
     # Enable cache-dit using BlockAdapter for transformer
     cache_dit.enable_cache(
-        pipeline.transformer,
+        (
+            BlockAdapter(
+                transformer=pipeline.transformer,
+                blocks=[
+                    pipeline.transformer.transformer_blocks,
+                    pipeline.transformer.single_transformer_blocks,
+                ],
+                forward_pattern=[ForwardPattern.Pattern_1, ForwardPattern.Pattern_2],
+                params_modifiers=[modifier],
+            )
+        ),
         cache_config=db_cache_config,
-        calibrator_config=calibrator_config,
     )
 
+    def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool = True) -> None:
+        """Refresh cache context for the transformer with new num_inference_steps.
+
-def enable_cache_for_flux2(pipeline: Any, cache_config: Any) -> Callable[[int], None]:
-    """Enable cache-dit for Flux.2-dev pipeline.
+        Args:
+            pipeline: The Flux2 pipeline instance.
+            num_inference_steps: New number of inference steps.
+        """
+        if cache_config.scm_steps_mask_policy is None:
+            cache_dit.refresh_context(pipeline.transformer, num_inference_steps=num_inference_steps, verbose=verbose)
+        else:
+            cache_dit.refresh_context(
+                pipeline.transformer,
+                cache_config=DBCacheConfig().reset(
+                    num_inference_steps=num_inference_steps,
+                    steps_computation_mask=cache_dit.steps_mask(
+                        mask_policy=cache_config.scm_steps_mask_policy,
+                        total_steps=num_inference_steps,
+                    ),
+                    steps_computation_policy=cache_config.scm_steps_policy,
+                ),
+                verbose=verbose,
+            )
+
+    return refresh_cache_context
 
 
+def enable_cache_for_glm_image(pipeline: Any, cache_config: Any) -> Callable[[int], None]:
+    """Enable cache-dit for GlmImage pipeline.
 
     Args:
-        pipeline: The Flux2 pipeline instance.
+        pipeline: The GlmImage pipeline instance.
         cache_config: DiffusionCacheConfig instance with cache configuration.
     Returns:
         A refresh function that can be called with a new ``num_inference_steps``
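
For orientation between hunks, here is a minimal usage sketch of the setup/refresh pair this file exposes. The `pipeline` and `cache_config` objects are assumed to come from the surrounding vllm-omni engine code; they are not constructed in this diff:

# Sketch only (not part of the PR). enable_cache_for_flux2 hooks cache-dit
# into the transformer once at startup, then returns a closure for refreshes.
refresh = enable_cache_for_flux2(pipeline, cache_config)

# Per request: if the step count changes, the per-step cache bookkeeping is
# stale and must be rebuilt before the denoising loop runs.
refresh(pipeline, num_inference_steps=28)

The BlockAdapter route matters here because the Flux transformer has two block lists (`transformer_blocks` and `single_transformer_blocks`) with different forward patterns, and the ParamsModifier carries the DBCache and TaylorSeer settings down to them.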
@@ -1226,23 +1270,25 @@ def enable_cache_for_flux2(pipeline: Any, cache_config: Any) -> Callable[[int],
     )
 
     logger.info(
-        f"Enabling cache-dit on Flux transformer with BlockAdapter: "
+        f"Enabling cache-dit on GlmImage transformer with BlockAdapter: "
         f"Fn={db_cache_config.Fn_compute_blocks}, "
         f"Bn={db_cache_config.Bn_compute_blocks}, "
         f"W={db_cache_config.max_warmup_steps}, "
     )
 
     # Enable cache-dit using BlockAdapter for transformer
+    # Note: We don't use patch_functor here because it's designed for diffusers' GlmImage,
+    # and our vllm-omni implementation has a different forward signature.
+    # We use ForwardPattern.Pattern_0 because our block returns (hidden_states, encoder_hidden_states)
     cache_dit.enable_cache(
         (
             BlockAdapter(
                 transformer=pipeline.transformer,
-                blocks=[
-                    pipeline.transformer.transformer_blocks,
-                    pipeline.transformer.single_transformer_blocks,
-                ],
-                forward_pattern=[ForwardPattern.Pattern_1, ForwardPattern.Pattern_2],
+                blocks=pipeline.transformer.transformer_blocks,
+                forward_pattern=ForwardPattern.Pattern_0,
                 params_modifiers=[modifier],
+                patch_functor=None,
+                has_separate_cfg=True,
             )
         ),
         cache_config=db_cache_config,
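
To make the Pattern_0 comment in this hunk concrete: cache-dit's ForwardPattern.Pattern_0 matches blocks that take and return `(hidden_states, encoder_hidden_states)` in the same order. A hypothetical stand-in for the block shape being assumed (this class is illustrative, not the vllm-omni implementation):

import torch
from torch import nn

class GlmImageBlockSketch(nn.Module):
    # Illustrative only: the real vllm-omni GlmImage block does attention and
    # feed-forward work; what Pattern_0 cares about is this exact signature.
    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        return hidden_states, encoder_hidden_states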
@@ -1252,7 +1298,7 @@ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool
"""Refresh cache context for the transformer with new num_inference_steps.

Args:
pipeline: The Flux2 pipeline instance.
pipeline: The GlmImage pipeline instance.
num_inference_steps: New number of inference steps.
"""
if cache_config.scm_steps_mask_policy is None:
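
On the scm refresh branch above: `steps_computation_mask` presumably marks which denoising steps run the full transformer versus which may reuse cached block outputs. A toy illustration of the idea; the helper and the fixed-interval policy below are invented for exposition, while `cache_dit.steps_mask` is the real API the diff calls:

def toy_steps_mask(total_steps: int, compute_every: int = 2) -> list[bool]:
    # Invented example: True = compute this step fully, False = allow cache
    # reuse. Step 0 satisfies 0 % compute_every == 0, so the cache is seeded.
    return [i % compute_every == 0 for i in range(total_steps)]

# toy_steps_mask(8) -> [True, False, True, False, True, False, True, False]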