Skip to content
1 change: 1 addition & 0 deletions docs/models/supported_models.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ th {
|`GlmImageForConditionalGeneration` | GLM-Image | `zai-org/GLM-Image` |
|`NextStep11Pipeline` | NextStep-1.1 | `stepfun-ai/NextStep-1.1` |
|`MiMoAudioForConditionalGeneration` | MiMo-Audio-7B-Instruct | `XiaomiMiMo/MiMo-Audio-7B-Instruct` |
|`Flux2Pipeline` | FLUX.2-dev | `black-forest-labs/FLUX.2-dev` |


## List of Supported Models for NPU
Expand Down
1 change: 1 addition & 0 deletions docs/user_guide/diffusion/parallelism_acceleration.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ The following table shows which models are currently supported by parallelism me
| **Stable-Diffusion3.5** | `stabilityai/stable-diffusion-3.5` | ❌ | ❌ | ❌ | ✅ | ✅ |
| **FLUX.2-klein** | `black-forest-labs/FLUX.2-klein-4B` | ❌ | ❌ | ❌ | ✅ | ❌ |
| **FLUX.1-dev** | `black-forest-labs/FLUX.1-dev` | ❌ | ❌ | ✅ | ✅ | ❌ |
| **FLUX.2-dev** | `black-forest-labs/FLUX.2-dev` | ❌ | ❌ | ❌ | ✅ | ❌ |

!!! note "TP Limitations for Diffusion Models"
We currently implement Tensor Parallelism (TP) only for the DiT (Diffusion Transformer) blocks. This is because the `text_encoder` component in vLLM-Omni uses the original Transformers implementation, which does not yet support TP.
Expand Down
1 change: 1 addition & 0 deletions docs/user_guide/diffusion_acceleration.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ The following table shows which models are currently supported by each accelerat
| **Bagel** | `ByteDance-Seed/BAGEL-7B-MoT` | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
| **FLUX.1-dev** | `black-forest-labs/FLUX.1-dev` | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
| **NextStep-1.1** | `stepfun-ai/NextStep-1.1` | ❌ | ❌ | ❌ | ❌ | ✅ | ✅ | ✅ |
| **FLUX.2-klein** | `black-forest-labs/FLUX.2-klein-4B` | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |

### VideoGen

Expand Down
79 changes: 79 additions & 0 deletions vllm_omni/diffusion/cache/cache_dit_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,84 @@ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool
return refresh_cache_context


def enable_cache_for_flux2_klein(pipeline: Any, cache_config: Any) -> Callable[[int], None]:
    """Enable cache-dit for the FLUX.2-klein-4B pipeline.

    Registers ``pipeline.transformer`` with cache-dit through a
    ``BlockAdapter`` that covers both ``transformer_blocks`` and
    ``single_transformer_blocks``.

    Args:
        pipeline: The FLUX.2-klein-4B pipeline instance.
        cache_config: DiffusionCacheConfig instance with cache configuration.

    Returns:
        A refresh function, ``refresh_cache_context(pipeline,
        num_inference_steps, verbose=True)``, that updates the cache context
        of the pipeline's transformer for a new ``num_inference_steps``.

    NOTE(review): the ``Callable[[int], None]`` return annotation does not
    match the returned callable's actual parameters; it is left unchanged to
    stay consistent with ``enable_cache_for_sd3`` in this module.
    """
    # Build DBCacheConfig for the transformer.
    db_cache_config = _build_db_cache_config(cache_config)
    # Always force Fn_compute_blocks to 2 for this model, overriding whatever
    # the generic builder produced: per the original author's note, Fn=1
    # causes quality degradation on FLUX.2-klein.
    db_cache_config.Fn_compute_blocks = 2

    # The TaylorSeer calibrator is optional; ``None`` leaves it disabled.
    calibrator = None
    if cache_config.enable_taylorseer:
        taylorseer_order = cache_config.taylorseer_order
        calibrator = TaylorSeerCalibratorConfig(taylorseer_order=taylorseer_order)
        logger.info(f"TaylorSeer enabled with order={taylorseer_order}")

    # Build ParamsModifier for the transformer.
    modifier = ParamsModifier(
        cache_config=db_cache_config,
        calibrator_config=calibrator,
    )

    logger.info(
        "Enabling cache-dit on Flux2-Klein transformer with BlockAdapter: "
        f"Fn={db_cache_config.Fn_compute_blocks}, "
        f"Bn={db_cache_config.Bn_compute_blocks}, "
        f"W={db_cache_config.max_warmup_steps}"
    )

    # Enable cache-dit using a BlockAdapter for the transformer. The two block
    # lists are paired positionally with the two forward patterns.
    cache_dit.enable_cache(
        BlockAdapter(
            transformer=pipeline.transformer,
            blocks=[
                pipeline.transformer.transformer_blocks,
                pipeline.transformer.single_transformer_blocks,
            ],
            forward_pattern=[ForwardPattern.Pattern_1, ForwardPattern.Pattern_2],
            params_modifiers=[modifier],
        ),
        cache_config=db_cache_config,
    )

    def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool = True) -> None:
        """Refresh cache context for the transformer with new num_inference_steps.

        Args:
            pipeline: The FLUX.2-klein-4B pipeline instance.
            num_inference_steps: New number of inference steps.
            verbose: Forwarded unchanged to ``cache_dit.refresh_context``.
        """
        if cache_config.scm_steps_mask_policy is None:
            # No steps-mask policy configured: only the step count changes.
            cache_dit.refresh_context(
                pipeline.transformer,
                num_inference_steps=num_inference_steps,
                verbose=verbose,
            )
            return

        # A steps-mask policy is configured: rebuild the mask for the new step
        # count and carry over the Fn override chosen above.
        cache_dit.refresh_context(
            pipeline.transformer,
            cache_config=DBCacheConfig().reset(
                num_inference_steps=num_inference_steps,
                steps_computation_mask=cache_dit.steps_mask(
                    mask_policy=cache_config.scm_steps_mask_policy,
                    total_steps=num_inference_steps,
                ),
                Fn_compute_blocks=db_cache_config.Fn_compute_blocks,
                steps_computation_policy=cache_config.scm_steps_policy,
            ),
            verbose=verbose,
        )

    return refresh_cache_context


def enable_cache_for_sd3(pipeline: Any, cache_config: Any) -> Callable[[int], None]:
"""Enable cache-dit for StableDiffusion3Pipeline.

Expand Down Expand Up @@ -859,6 +937,7 @@ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool
"Wan22I2VPipeline": enable_cache_for_wan22,
"Wan22TI2VPipeline": enable_cache_for_wan22,
"FluxPipeline": enable_cache_for_flux,
"Flux2KleinPipeline": enable_cache_for_flux2_klein,
"LongCatImagePipeline": enable_cache_for_longcat_image,
"LongCatImageEditPipeline": enable_cache_for_longcat_image,
"StableDiffusion3Pipeline": enable_cache_for_sd3,
Expand Down
17 changes: 17 additions & 0 deletions vllm_omni/diffusion/models/flux2/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Flux2 diffusion model components."""

from vllm_omni.diffusion.models.flux2.flux2_transformer import (
Flux2Transformer2DModel,
)
from vllm_omni.diffusion.models.flux2.pipeline_flux2 import (
Flux2Pipeline,
get_flux2_post_process_func,
)

# Public names of this package; each one is imported above.
__all__ = [
"Flux2Pipeline",
"Flux2Transformer2DModel",
"get_flux2_post_process_func",
]
Loading
Loading