Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/user_guide/diffusion_features.md
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ The following tables show which models support each feature:
| **FLUX.1-dev** | ❌ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ |
| **FLUX.2-klein** | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ |
| **FLUX.1-Kontext-dev** | ❌ | ❌ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ |
| **FLUX.2-dev** | | | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ |
| **FLUX.2-dev** | | | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ |
| **GLM-Image** | ❌ | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ |
| **HunyuanImage3** | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
| **LongCat-Image** | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ |
Expand Down
105 changes: 104 additions & 1 deletion tests/diffusion/cache/test_teacache_extractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import torch

from tests.utils import hardware_test
from vllm_omni.diffusion.cache.teacache.extractors import extract_flux2_klein_context
from vllm_omni.diffusion.cache.teacache.extractors import extract_flux2_context, extract_flux2_klein_context
from vllm_omni.diffusion.models.flux2_klein.flux2_klein_transformer import (
Flux2Transformer2DModel,
)
Expand Down Expand Up @@ -174,3 +174,106 @@ def test_invalid_module_raises_error(self):
img_ids=torch.randint(0, 64, (1, 1024, 4)),
txt_ids=torch.randint(0, 64, (1, 512, 4)),
)


class TestFlux2Extractor(BaseExtractorTest):
    """Test extract_flux2_context function."""

    def get_extractor(self):
        """Return the extractor under test."""
        return extract_flux2_context

    @pytest.fixture
    def flux2_module(self):
        """Create a minimal Flux2Transformer2DModel for testing."""
        # Imported locally so this class uses the flux2 transformer rather than
        # the flux2_klein class of the same name imported at module level.
        from vllm_omni.diffusion.models.flux2.flux2_transformer import Flux2Transformer2DModel

        model = Flux2Transformer2DModel(
            num_layers=2,
            num_single_layers=2,
            num_attention_heads=48,
            attention_head_dim=128,
            joint_attention_dim=15360,
        )
        return model

    def get_module(self, flux2_module):
        """Return the module fixture in the form the base test expects."""
        return flux2_module

    @pytest.fixture
    def sample_inputs(self):
        """Create sample input tensors for Flux2.

        Note: hidden_states uses in_channels=128 (default for Flux2),
        not inner_dim=6144. The x_embedder projects from 128 -> 6144.
        encoder_hidden_states uses joint_attention_dim=15360 (model default),
        which then gets projected to inner_dim=6144 by context_embedder.
        """
        batch_size = 1
        img_seq_len = 1024
        txt_seq_len = 512
        in_channels = 128  # Model default in_channels
        txt_dim = 15360  # Model default joint_attention_dim

        return {
            "hidden_states": torch.randn(batch_size, img_seq_len, in_channels),
            "encoder_hidden_states": torch.randn(batch_size, txt_seq_len, txt_dim),
            "timestep": torch.tensor([500]),
            "img_ids": torch.randint(0, 64, (batch_size, img_seq_len, 4)),
            "txt_ids": torch.randint(0, 64, (batch_size, txt_seq_len, 4)),
            "guidance": torch.tensor([3.5]),
        }

    def get_sample_inputs(self, sample_inputs):
        """Return the sample-input fixture in the form the base test expects."""
        return sample_inputs

    @hardware_test(res={"cuda": "L4"}, num_cards=1)
    def test_modulated_input_shape(self, flux2_module, sample_inputs):
        """Test that modulated_input has correct shape matching the model's inner_dim.

        Note: After x_embedder projection, hidden_states are projected from
        in_channels (128) to inner_dim (6144), so modulated_input should match
        the projected shape, not the input shape.
        """
        # Exercise the Flux2 extractor itself (not the Flux2-Klein one) so this
        # class actually covers extract_flux2_context.
        context = extract_flux2_context(flux2_module, **sample_inputs)

        batch_size, img_seq_len, _ = sample_inputs["hidden_states"].shape
        inner_dim = flux2_module.inner_dim
        assert context.modulated_input.shape == (batch_size, img_seq_len, inner_dim)

    @hardware_test(res={"cuda": "L4"}, num_cards=1)
    def test_run_transformer_blocks_callable(self, flux2_module, sample_inputs):
        """Test that run_transformer_blocks is callable."""
        context = extract_flux2_context(flux2_module, **sample_inputs)
        assert callable(context.run_transformer_blocks)

    @hardware_test(res={"cuda": "L4"}, num_cards=1)
    def test_postprocess_callable(self, flux2_module, sample_inputs):
        """Test that postprocess is callable."""
        context = extract_flux2_context(flux2_module, **sample_inputs)
        assert callable(context.postprocess)

    # NOTE(review): unlike the other model-backed tests in this class, this one
    # carries no hardware marker -- confirm that is intended.
    def test_without_guidance(self, flux2_module, sample_inputs):
        """Test context extraction works without guidance (no CFG)."""
        inputs = sample_inputs.copy()
        inputs["guidance"] = None

        context = extract_flux2_context(flux2_module, **inputs)

        assert context is not None
        assert context.temb is not None

    @pytest.mark.cpu
    def test_invalid_module_raises_error(self):
        """Test that invalid module without transformer_blocks raises ValueError."""
        invalid_module = Mock()
        invalid_module.transformer_blocks = []

        with pytest.raises(ValueError, match="Module must have transformer_blocks"):
            extract_flux2_context(
                invalid_module,
                hidden_states=torch.randn(1, 1024, 6144),
                encoder_hidden_states=torch.randn(1, 512, 15360),
                timestep=torch.tensor([500]),
                img_ids=torch.randint(0, 64, (1, 1024, 4)),
                txt_ids=torch.randint(0, 64, (1, 512, 4)),
            )
27 changes: 27 additions & 0 deletions vllm_omni/diffusion/cache/teacache/coefficient_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from vllm_omni.diffusion.hooks import HookRegistry, ModelHook
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
from vllm_omni.diffusion.models.bagel.pipeline_bagel import BagelPipeline
from vllm_omni.diffusion.models.flux2.pipeline_flux2 import Flux2Pipeline
from vllm_omni.diffusion.models.stable_audio.pipeline_stable_audio import StableAudioPipeline
from vllm_omni.diffusion.request import OmniDiffusionRequest
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
Expand Down Expand Up @@ -103,6 +104,31 @@ def install_hook(transformer: Any, hook: DataCollectionHook) -> None:
registry.register_hook(hook._HOOK_NAME, hook)


class Flux2Adapter:
    """Coefficient-estimation adapter for the Flux2 pipeline.

    Groups three static entry points: building the pipeline, locating its
    transformer, and attaching a data-collection hook to that transformer.
    """

    @staticmethod
    def load_pipeline(model_path: str, device: str = "cuda", dtype: torch.dtype = torch.bfloat16) -> Flux2Pipeline:
        """Build a Flux2Pipeline, load its weights, and move it to ``device``.

        Args:
            model_path: Value forwarded as ``model`` to the diffusion config.
            device: Target passed to ``pipeline.to``.
            dtype: Value forwarded as ``dtype`` to the diffusion config.

        Returns:
            The pipeline after weight loading and device placement.
        """
        config = OmniDiffusionConfig.from_kwargs(model=model_path, dtype=dtype)
        config.model_class_name = "Flux2Pipeline"

        flux2 = Flux2Pipeline(od_config=config)
        # Weights are loaded before the pipeline is moved to the target device.
        DiffusersPipelineLoader(LoadConfig()).load_weights(flux2)
        flux2.to(device)
        return flux2

    @staticmethod
    def get_transformer(pipeline: Any) -> tuple[Any, str]:
        """Return the pipeline's transformer together with its class name."""
        transformer = pipeline.transformer
        return transformer, transformer.__class__.__name__

    @staticmethod
    def install_hook(transformer: Any, hook: DataCollectionHook) -> None:
        """Register ``hook`` under its ``_HOOK_NAME`` in the transformer's hook registry."""
        HookRegistry.get_or_create(transformer).register_hook(hook._HOOK_NAME, hook)


class DefaultAdapter:
"""Default adapter for standard diffusers pipelines."""

Expand All @@ -123,6 +149,7 @@ def install_hook(transformer: Any, hook: DataCollectionHook) -> None:
# Maps a model-type key to the adapter class that handles that model.
# NOTE(review): how a key is matched against a model is not visible in this
# excerpt -- confirm against the lookup site before adding entries.
_MODEL_ADAPTERS: dict[str, type] = {
"Bagel": BagelAdapter,
"StableAudio": StableAudioAdapter,
"Flux2": Flux2Adapter,
}

_EPSILON = 1e-6
Expand Down
9 changes: 9 additions & 0 deletions vllm_omni/diffusion/cache/teacache/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,15 @@
-1.04182570e01,
6.78098549e-01,
],
# Flux2 transformer coefficients
# Copied from Qwen-Image; these need to be tuned specifically for Flux2 in the future.
"Flux2Transformer2DModel": [
-4.50000000e02,
2.80000000e02,
-4.50000000e01,
3.20000000e00,
-2.00000000e-02,
],
}


Expand Down
140 changes: 140 additions & 0 deletions vllm_omni/diffusion/cache/teacache/extractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import torch.nn as nn

from vllm_omni.diffusion.forward_context import get_forward_context
from vllm_omni.platforms import current_omni_platform


@dataclass
Expand Down Expand Up @@ -827,6 +828,144 @@ def postprocess(h: torch.Tensor) -> Any:
)


def extract_flux2_context(
    module: nn.Module,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor = None,
    timestep: torch.LongTensor = None,
    img_ids: torch.Tensor = None,
    txt_ids: torch.Tensor = None,
    guidance: torch.Tensor | None = None,
    joint_attention_kwargs: dict[str, Any] | None = None,
    return_dict: bool = True,
    **kwargs: Any,
) -> CacheContext:
    """
    Build the TeaCache ``CacheContext`` for a Flux2Transformer2DModel.

    The function performs the model's preprocessing eagerly (timestep/guidance
    embedding, modulation parameters, input projections, rotary embeddings),
    derives the modulated input used for the cache decision, and packages the
    transformer-block loop and the output head as deferred callables.

    Args:
        module: Flux2Transformer2DModel instance; must expose a non-empty
            ``transformer_blocks``.
        hidden_states: Image-stream input, projected by ``module.x_embedder``.
        encoder_hidden_states: Text-stream input, projected by
            ``module.context_embedder``.
        timestep: Diffusion timestep; cast to the dtype of ``hidden_states``
            and multiplied by 1000 before embedding.
        img_ids: Image position ids for ``module.pos_embed``; when 3-D only
            the first entry along dim 0 is used.
        txt_ids: Text position ids, handled the same way as ``img_ids``.
        guidance: Optional guidance value; when given it is cast and scaled
            like ``timestep``.
        joint_attention_kwargs: Forwarded unchanged to every block.
        return_dict: When true, ``postprocess`` returns a
            ``Transformer2DModelOutput``; otherwise a 1-tuple with the tensor.
        **kwargs: Accepted and ignored.

    Returns:
        CacheContext holding the modulated input, the projected hidden states
        for both streams, ``temb``, and the two deferred callables.

    Raises:
        ValueError: If ``module`` has no ``transformer_blocks`` or it is empty.
    """

    from diffusers.models.modeling_outputs import Transformer2DModelOutput

    if not hasattr(module, "transformer_blocks") or len(module.transformer_blocks) == 0:
        raise ValueError("Module must have transformer_blocks")

    # ------------------------------------------------------------------
    # Preprocessing
    # ------------------------------------------------------------------
    # Text length is recorded up front so the text tokens can be stripped
    # again after the single-stream blocks.
    text_len = encoder_hidden_states.shape[1]
    compute_dtype = hidden_states.dtype

    scaled_timestep = timestep.to(compute_dtype) * 1000
    scaled_guidance = guidance.to(compute_dtype) * 1000 if guidance is not None else None

    temb = module.time_guidance_embed(scaled_timestep, scaled_guidance)

    img_mod = module.double_stream_modulation_img(temb)
    txt_mod = module.double_stream_modulation_txt(temb)
    single_mod = module.single_stream_modulation(temb)[0]

    img_tokens = module.x_embedder(hidden_states)
    txt_tokens = module.context_embedder(encoder_hidden_states)

    img_pos = img_ids[0] if img_ids.ndim == 3 else img_ids
    txt_pos = txt_ids[0] if txt_ids.ndim == 3 else txt_ids

    on_npu = current_omni_platform.is_npu()

    def _rotary_for(position_ids):
        """Compute rotary embeddings; on NPU, embed on CPU and move the result back."""
        if on_npu:
            freqs_cos, freqs_sin = module.pos_embed(position_ids.cpu())
            return (freqs_cos.npu(), freqs_sin.npu())
        return module.pos_embed(position_ids)

    image_rotary_emb = _rotary_for(img_pos)
    text_rotary_emb = _rotary_for(txt_pos)
    # Text comes first, matching the [text, image] token order used after the
    # double-stream blocks.
    concat_rotary_emb = tuple(
        torch.cat([text_rotary_emb[part], image_rotary_emb[part]], dim=0) for part in (0, 1)
    )

    # ------------------------------------------------------------------
    # Modulated input (drives the cache decision)
    # ------------------------------------------------------------------
    first_block = module.transformer_blocks[0]
    (shift_msa, scale_msa, _gate_msa), _ = img_mod
    modulated_input = (1 + scale_msa) * first_block.norm1(img_tokens) + shift_msa

    # ------------------------------------------------------------------
    # Deferred transformer execution
    # ------------------------------------------------------------------
    def run_transformer_blocks():
        """Run the double-stream then single-stream blocks; return image tokens in a 1-tuple."""
        img, txt = img_tokens, txt_tokens

        for double_block in module.transformer_blocks:
            txt, img = double_block(
                hidden_states=img,
                encoder_hidden_states=txt,
                temb_mod_params_img=img_mod,
                temb_mod_params_txt=txt_mod,
                image_rotary_emb=concat_rotary_emb,
                joint_attention_kwargs=joint_attention_kwargs,
            )

        joint = torch.cat([txt, img], dim=1)
        for single_block in module.single_transformer_blocks:
            joint = single_block(
                hidden_states=joint,
                encoder_hidden_states=None,
                temb_mod_params=single_mod,
                image_rotary_emb=concat_rotary_emb,
                joint_attention_kwargs=joint_attention_kwargs,
            )

        # Drop the leading text tokens so only the image stream is returned.
        return (joint[:, text_len:, ...],)

    # ------------------------------------------------------------------
    # Deferred postprocessing
    # ------------------------------------------------------------------
    def postprocess(h):
        """Apply the output norm and projection, then wrap per ``return_dict``."""
        output = module.proj_out(module.norm_out(h, temb))
        return Transformer2DModelOutput(sample=output) if return_dict else (output,)

    return CacheContext(
        modulated_input=modulated_input,
        hidden_states=img_tokens,
        encoder_hidden_states=txt_tokens,
        temb=temb,
        run_transformer_blocks=run_transformer_blocks,
        postprocess=postprocess,
    )


# Registry for model-specific extractors
# Key: Transformer class name
# Value: extractor function with signature (module, *args, **kwargs) -> CacheContext
Expand All @@ -839,6 +978,7 @@ def postprocess(h: torch.Tensor) -> Any:
"ZImageTransformer2DModel": extract_zimage_context,
"Flux2Klein": extract_flux2_klein_context,
"StableAudioDiTModel": extract_stable_audio_context,
"Flux2Transformer2DModel": extract_flux2_context,
# Future models:
# "FluxTransformer2DModel": extract_flux_context,
# "CogVideoXTransformer3DModel": extract_cogvideox_context,
Expand Down
Loading