Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion docs/design/feature/teacache.md
Original file line number Diff line number Diff line change
Expand Up @@ -326,9 +326,41 @@ for prompt in tqdm(prompts, desc="Collecting data"):

# Estimate coefficients
coeffs = estimator.estimate(poly_order=4)
print(f"Estimated coefficients: {coeffs.tolist()}")
print(f"Estimated coefficients: {coeffs}")
```

Note: some models may require the vLLM context and config to be initialized to initialize vLLM modules. To this end, you may need a workaround like the following to be able to run coefficient estimation.
```python
from vllm_omni.diffusion.forward_context import set_forward_context
from vllm_omni.diffusion.distributed.parallel_state import (
init_distributed_environment,
initialize_model_parallel,
)
from vllm.config import VllmConfig
...

if __name__ == "__main__":
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "8192"
os.environ["LOCAL_RANK"] = "0"
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"

vllm_config = VllmConfig()
init_distributed_environment()
initialize_model_parallel()

# NOTE: you may have to pass an initialized OmniDiffusionConfig as a kwarg
# here to make current sp checks happy; if this is the case, just create one
# .from_kwargs() with the model name to get around this check for now,
# since your estimator subclass should handle the actual model configuration.
#
# This will be cleaned up in the future
with set_forward_context(vllm_config):
<create the estimator + run estimation here>
```


**Data Statistics Guide:**

| Metric | Good Range | Warning Signs |
Expand Down
4 changes: 2 additions & 2 deletions docs/user_guide/diffusion_features.md
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,8 @@ The following tables show which models support each feature:
| **FLUX.2-dev** | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ |
| **GLM-Image** | ❌ | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ |
| **HunyuanImage3** | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
| **LongCat-Image** | | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ |
| **LongCat-Image-Edit** | | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ |
| **LongCat-Image** | | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ |
| **LongCat-Image-Edit** | | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ |
| **MagiHuman** | ❌ | ❌ | ❌ | ❓ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ |
| **MammothModa2(T2I)** | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| **Nextstep_1(T2I)** | ❓ | ❓ | ❌ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ |
Expand Down
118 changes: 52 additions & 66 deletions vllm_omni/diffusion/cache/teacache/coefficient_estimator.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,18 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os
from typing import Any

import numpy as np
import torch
from vllm.config import LoadConfig
from vllm.utils.torch_utils import set_default_torch_dtype
from vllm.transformers_utils.config import get_hf_file_to_dict

from vllm_omni.diffusion.cache.teacache.extractors import get_extractor
from vllm_omni.diffusion.data import OmniDiffusionConfig
from vllm_omni.diffusion.data import OmniDiffusionConfig, TransformerConfig
from vllm_omni.diffusion.hooks import HookRegistry, ModelHook
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
from vllm_omni.diffusion.models.bagel.pipeline_bagel import BagelPipeline
from vllm_omni.diffusion.models.flux2.pipeline_flux2 import Flux2Pipeline
from vllm_omni.diffusion.models.stable_audio.pipeline_stable_audio import StableAudioPipeline
from vllm_omni.diffusion.request import OmniDiffusionRequest
from vllm_omni.inputs.data import OmniDiffusionSamplingParams

Expand All @@ -36,6 +34,7 @@ def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:

def new_forward(self, module: torch.nn.Module, *args: Any, **kwargs: Any) -> Any:
ctx = self.extractor_fn(module, *args, **kwargs)
# NOTE: We upcast to float32 to also handle bfloat16.
modulated_input_cpu = ctx.modulated_input.detach().float().cpu().numpy()

outputs = ctx.run_transformer_blocks()
Expand All @@ -54,102 +53,90 @@ def stop_collection(self) -> list[tuple[np.ndarray, np.ndarray]]:
return list(self.current_trajectory)


class BagelAdapter:
"""Adapter for Bagel model."""
class DefaultAdapter:
"""Default adapter for standard diffusers pipelines."""

@staticmethod
def load_pipeline(model_path: str, device: str = "cuda", dtype: torch.dtype = torch.bfloat16) -> BagelPipeline:
od_config = OmniDiffusionConfig.from_kwargs(model=model_path, dtype=dtype)
od_config.model_class_name = "BagelPipeline"
model_class_name = None
uses_tf_config = True

@classmethod
def load_pipeline(cls, model_path: str, device: str, dtype: torch.dtype) -> Any:
if cls.model_class_name is None:
raise ValueError("Adapter doesn't have a set class name.")

pipeline = BagelPipeline(od_config=od_config)
loader = DiffusersPipelineLoader(LoadConfig())
loader.load_weights(pipeline)
pipeline.to(device)
return pipeline
od_config = OmniDiffusionConfig.from_kwargs(
model_class_name=cls.model_class_name,
model=model_path,
dtype=dtype,
)

if cls.uses_tf_config:
# TODO (Alex): Refactor to handle tf_model_config in OmniDiffusionConfig
# instead of OmniDiffusion and remove the manual population here
tf_config_dict = get_hf_file_to_dict(
os.path.join("transformer", "config.json"),
od_config.model,
)
od_config.tf_model_config = TransformerConfig.from_dict(tf_config_dict)

loader = DiffusersPipelineLoader(LoadConfig(), od_config=od_config)
# load_model will handle dtypes / device placement, put in .eval() mode
return loader.load_model(od_config=od_config, load_device=device)

@staticmethod
def get_transformer(pipeline: Any) -> tuple[Any, str]:
return pipeline.bagel, "Bagel"
return pipeline.transformer, pipeline.transformer.__class__.__name__

@staticmethod
def install_hook(transformer: Any, hook: DataCollectionHook) -> None:
registry = HookRegistry.get_or_create(transformer)
registry.register_hook(hook._HOOK_NAME, hook)


class StableAudioAdapter:
"""Adapter for Stable Audio Open 1.0 coefficient estimation."""

@staticmethod
def load_pipeline(model_path: str, device: str = "cuda", dtype: torch.dtype = torch.float16) -> Any:
od_config = OmniDiffusionConfig.from_kwargs(model=model_path, dtype=dtype)

# Strictly necessary because we bypass loader.load_model()
with set_default_torch_dtype(dtype):
pipeline = StableAudioPipeline(od_config=od_config)
class BagelAdapter(DefaultAdapter):
"""Adapter for Bagel model."""

loader = DiffusersPipelineLoader(LoadConfig())
loader.load_weights(pipeline)
pipeline.to(device)
return pipeline
model_class_name = "BagelPipeline"
# Skip the hack for loading the tf model config,
# because bagel doesn't use it.
uses_tf_config = False

@staticmethod
def get_transformer(pipeline: Any) -> tuple[Any, str]:
return pipeline.transformer, "StableAudioDiTModel"
return pipeline.bagel, "Bagel"

@staticmethod
def install_hook(transformer: Any, hook: DataCollectionHook) -> None:
registry = HookRegistry.get_or_create(transformer)
registry.register_hook(hook._HOOK_NAME, hook)


class Flux2Adapter:
class Flux2Adapter(DefaultAdapter):
"""Adapter for Flux2 model coefficient estimation."""

@staticmethod
def load_pipeline(model_path: str, device: str = "cuda", dtype: torch.dtype = torch.bfloat16) -> Flux2Pipeline:
"""Load Flux2 pipeline for coefficient estimation."""
od_config = OmniDiffusionConfig.from_kwargs(model=model_path, dtype=dtype)
od_config.model_class_name = "Flux2Pipeline"

pipeline = Flux2Pipeline(od_config=od_config)
loader = DiffusersPipelineLoader(LoadConfig())
loader.load_weights(pipeline)
pipeline.to(device)
return pipeline
model_class_name = "Flux2Pipeline"

@staticmethod
def get_transformer(pipeline: Any) -> tuple[Any, str]:
return pipeline.transformer, pipeline.transformer.__class__.__name__

@staticmethod
def install_hook(transformer: Any, hook: DataCollectionHook) -> None:
registry = HookRegistry.get_or_create(transformer)
registry.register_hook(hook._HOOK_NAME, hook)
class LongCatAdapter(DefaultAdapter):
"""Adapter for LongCat Image - NOTE: currently this model needs the vLLM
context to be correctly configured to actually run the estimation, since it
uses vLLM norm layers etc.
"""

model_class_name = "LongCatImagePipeline"

class DefaultAdapter:
"""Default adapter for standard diffusers pipelines."""

@staticmethod
def load_pipeline(model_path: str, device: str, dtype: torch.dtype) -> Any:
raise NotImplementedError("DefaultAdapter.load_pipeline not implemented")

@staticmethod
def get_transformer(pipeline: Any) -> tuple[Any, str]:
return pipeline.transformer, pipeline.transformer.__class__.__name__
class StableAudioAdapter(DefaultAdapter):
"""Adapter for Stable Audio Open 1.0 coefficient estimation."""

@staticmethod
def install_hook(transformer: Any, hook: DataCollectionHook) -> None:
registry = HookRegistry.get_or_create(transformer)
registry.register_hook(hook._HOOK_NAME, hook)
model_class_name = "StableAudioPipeline"


Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this also be wrapped with set_default_torch_dtype(od_config.dtype) like BagelAdapter.load_pipeline was updated to do above?

Copy link
Copy Markdown
Contributor Author

@alex-jw-brooks alex-jw-brooks Mar 15, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had actually added a set_default_torch_dtype around the call to the load pipeline on the adapter instead of just putting it around the one line 🙂 the better way to do this is

        loader = DiffusersPipelineLoader(LoadConfig(), od_config=od_config)
        return loader.load_model(od_config=od_config, load_device=device)

because load_model will handle the device placement, put the model in eval mode, and handle the dtypes from the diffusion config. Updated both to avoid managing default dtypes manually and made sure the bagel one still runs

_MODEL_ADAPTERS: dict[str, type] = {
"Bagel": BagelAdapter,
"StableAudio": StableAudioAdapter,
"Flux2": Flux2Adapter,
"LongCat": LongCatAdapter,
}

_EPSILON = 1e-6
Expand Down Expand Up @@ -196,7 +183,6 @@ def __init__(
device: str = "cuda",
dtype: torch.dtype = torch.bfloat16,
):
# Add validation here ⬇️
if model_type not in _MODEL_ADAPTERS:
available_types = list(_MODEL_ADAPTERS.keys())
raise ValueError(
Expand All @@ -205,7 +191,7 @@ def __init__(
f"To add support for a new model, add an entry to _MODEL_ADAPTERS."
)

adapter = _MODEL_ADAPTERS.get(model_type, DefaultAdapter)
adapter = _MODEL_ADAPTERS[model_type]
self.pipeline = adapter.load_pipeline(model_path, device, dtype)
self.transformer, self.transformer_type = adapter.get_transformer(self.pipeline)
self.hook = DataCollectionHook(self.transformer_type)
Expand Down
2 changes: 2 additions & 0 deletions vllm_omni/diffusion/cache/teacache/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@
3.20000000e00,
-2.00000000e-02,
],
# LongCat Image transformer coefficients
"LongCatImageTransformer2DModel": [652.5980, -424.1615, 84.5526, -4.5923, 0.1694],
}


Expand Down
103 changes: 103 additions & 0 deletions vllm_omni/diffusion/cache/teacache/extractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,13 @@

import torch
import torch.nn as nn
from vllm.logger import init_logger

from vllm_omni.diffusion.forward_context import get_forward_context
from vllm_omni.platforms import current_omni_platform

logger = init_logger(__name__)


@dataclass
class CacheContext:
Expand Down Expand Up @@ -723,6 +726,105 @@ def postprocess(h):
)


def extract_longcat_context(
module: nn.Module, # LongCatImageTransformer2DModel
hidden_states,
timestep,
guidance,
encoder_hidden_states,
txt_ids,
img_ids,
**kwargs,
) -> CacheContext:
"""Extract the cache context for LongCat Image.

Similar to other extractors, this is currently the only code needed
for TeaCache support for LongCat image, and encapsulates preprocessing,
modulated input extraction, transformer execution, and postprocessing
logic.

Args & kawrgs are identical to the inputs to LongCat Image's forward.

Returns:
CacheContext with all information needed for generic caching
"""
# TODO (Alex) - Refactor TeaCache extractors to more tightly integrate with .forward
from diffusers.models.modeling_outputs import Transformer2DModelOutput
Comment thread
alex-jw-brooks marked this conversation as resolved.

# 1. Model specific preprocessing
Comment thread
alex-jw-brooks marked this conversation as resolved.
fwd_context = get_forward_context()
sp_size = module.parallel_config.sequence_parallel_size
if sp_size is not None and sp_size > 1:
# NOTE: For now, we set this to False on the forward context
# to be consistent with LongCat Image's current behavior when
# TeaCache is enabled. We do not need to reset it in post process
# since we should never split text embed in sp for this model.
fwd_context.split_text_embed_in_sp = False

hidden_states = module.x_embedder(hidden_states)
Comment on lines +756 to +764
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Preserve sequence-parallel sharding in LongCat extractor

In the SP case (sequence_parallel_size > 1), this code enables SP in the forward context but does not replicate the required LongCat preprocessing (chunking image hidden_states and RoPE by rank, as done in LongCatImageTransformer2DModel.forward). As a result, SP attention paths run on unsharded layouts, which yields invalid coefficient-collection behavior and can break distributed estimation runs.

Useful? React with 👍 / 👎.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Useful, but I think there are larger underlying problems in SP for this model at the moment (see #1556). I will investigate the fix for that as well, but see the same error with & without TeaCache at the moment, so open to any direction for how to handle it on this PR


timestep = timestep.to(hidden_states.dtype) * 1000

temb = module.time_embed(timestep, hidden_states.dtype)
encoder_hidden_states = module.context_embedder(encoder_hidden_states)

# Compute RoPE embeddings via rope_preparer module
# _sp_plan will automatically shard img_cos/img_sin (outputs 2, 3)
# txt_cos/txt_sin (outputs 0, 1) remain replicated for dual-stream attention
txt_cos, txt_sin, img_cos, img_sin = module.rope_preparer(txt_ids, img_ids)

# Reconstruct image_rotary_emb with chunked values
# Final shape: (txt_seq_len + img_seq_len // SP, head_dim)
image_rotary_emb = (
torch.cat([txt_cos, img_cos], dim=0),
torch.cat([txt_sin, img_sin], dim=0),
)

# 2. Extract the modulated output from the first mm-DiT block
first_block = module.transformer_blocks[0]
img_modulated = first_block.norm1(hidden_states, emb=temb)[0]

# 3. Define the transformer execution
def run_transformer_blocks():
"""Execute all Longcat transformer blocks."""
h = hidden_states
e = encoder_hidden_states
for block in module.transformer_blocks:
e, h = block(
Comment thread
alex-jw-brooks marked this conversation as resolved.
hidden_states=h,
encoder_hidden_states=e,
temb=temb,
image_rotary_emb=image_rotary_emb,
)

for block in module.single_transformer_blocks:
e, h = block(
hidden_states=h,
encoder_hidden_states=e,
temb=temb,
image_rotary_emb=image_rotary_emb,
)
# Hook expects hidden states to be first
return (h, e)

# 4. Postprocessing
def postprocess(h):
"""Apply Longcat-specific output postprocessing."""
h = module.norm_out(h, temb)
output = module.proj_out(h)
return Transformer2DModelOutput(sample=output)

# 5. Return the CacheContext
return CacheContext(
modulated_input=img_modulated,
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
run_transformer_blocks=run_transformer_blocks,
postprocess=postprocess,
)


def extract_stable_audio_context(
module: nn.Module,
hidden_states: torch.Tensor,
Expand Down Expand Up @@ -980,6 +1082,7 @@ def postprocess(h):
"Flux2Klein": extract_flux2_klein_context,
"StableAudioDiTModel": extract_stable_audio_context,
"Flux2Transformer2DModel": extract_flux2_context,
"LongCatImageTransformer2DModel": extract_longcat_context,
# Future models:
# "FluxTransformer2DModel": extract_flux_context,
# "CogVideoXTransformer3DModel": extract_cogvideox_context,
Expand Down
Loading