diff --git a/docs/design/feature/teacache.md b/docs/design/feature/teacache.md index 9fa315cee77..8577cff1f05 100644 --- a/docs/design/feature/teacache.md +++ b/docs/design/feature/teacache.md @@ -326,9 +326,41 @@ for prompt in tqdm(prompts, desc="Collecting data"): # Estimate coefficients coeffs = estimator.estimate(poly_order=4) -print(f"Estimated coefficients: {coeffs.tolist()}") +print(f"Estimated coefficients: {coeffs}") ``` +Note: some models may require the vLLM context and config to be initialized to initialize vLLM modules. To this end, you may need a workaround like the following to be able to run coefficient estimation. +```python +from vllm_omni.diffusion.forward_context import set_forward_context +from vllm_omni.diffusion.distributed.parallel_state import ( + init_distributed_environment, + initialize_model_parallel, +) +from vllm.config import VllmConfig +... + +if __name__ == "__main__": + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "8192" + os.environ["LOCAL_RANK"] = "0" + os.environ["RANK"] = "0" + os.environ["WORLD_SIZE"] = "1" + + vllm_config = VllmConfig() + init_distributed_environment() + initialize_model_parallel() + + # NOTE: you may have to pass an initialized OmniDiffusionConfig as a kwarg + # here to make current sp checks happy; if this is the case, just create one + # .from_kwargs() with the model name to get around this check for now, + # since your estimator subclass should handle the actual model configuration. + # + # This will be cleaned up in the future + with set_forward_context(vllm_config): + +``` + + **Data Statistics Guide:** | Metric | Good Range | Warning Signs | diff --git a/docs/user_guide/diffusion_features.md b/docs/user_guide/diffusion_features.md index 4e7003cce37..d70fdd9df7e 100644 --- a/docs/user_guide/diffusion_features.md +++ b/docs/user_guide/diffusion_features.md @@ -115,8 +115,8 @@ The following tables show which models support each feature: | **FLUX.2-dev** | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | | **GLM-Image** | ❌ | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | | **HunyuanImage3** | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | -| **LongCat-Image** | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | -| **LongCat-Image-Edit** | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | +| **LongCat-Image** | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | +| **LongCat-Image-Edit** | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | | **MagiHuman** | ❌ | ❌ | ❌ | ❓ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | | **MammothModa2(T2I)** | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | **Nextstep_1(T2I)** | ❓ | ❓ | ❌ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | diff --git a/vllm_omni/diffusion/cache/teacache/coefficient_estimator.py b/vllm_omni/diffusion/cache/teacache/coefficient_estimator.py index baec21c2762..38c805c28db 100644 --- a/vllm_omni/diffusion/cache/teacache/coefficient_estimator.py +++ b/vllm_omni/diffusion/cache/teacache/coefficient_estimator.py @@ -1,20 +1,18 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import os from typing import Any import numpy as np import torch from vllm.config import LoadConfig -from vllm.utils.torch_utils import set_default_torch_dtype +from vllm.transformers_utils.config import get_hf_file_to_dict from vllm_omni.diffusion.cache.teacache.extractors import get_extractor -from vllm_omni.diffusion.data import OmniDiffusionConfig +from vllm_omni.diffusion.data import OmniDiffusionConfig, TransformerConfig from vllm_omni.diffusion.hooks import HookRegistry, ModelHook from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader -from vllm_omni.diffusion.models.bagel.pipeline_bagel import BagelPipeline -from vllm_omni.diffusion.models.flux2.pipeline_flux2 import Flux2Pipeline -from vllm_omni.diffusion.models.stable_audio.pipeline_stable_audio import StableAudioPipeline from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.inputs.data import OmniDiffusionSamplingParams @@ -36,6 +34,7 @@ def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module: def new_forward(self, module: torch.nn.Module, *args: Any, **kwargs: Any) -> Any: ctx = self.extractor_fn(module, *args, **kwargs) + # NOTE: We upcast to float32 to also handle bfloat16. modulated_input_cpu = ctx.modulated_input.detach().float().cpu().numpy() outputs = ctx.run_transformer_blocks() @@ -54,23 +53,39 @@ def stop_collection(self) -> list[tuple[np.ndarray, np.ndarray]]: return list(self.current_trajectory) -class BagelAdapter: - """Adapter for Bagel model.""" +class DefaultAdapter: + """Default adapter for standard diffusers pipelines.""" - @staticmethod - def load_pipeline(model_path: str, device: str = "cuda", dtype: torch.dtype = torch.bfloat16) -> BagelPipeline: - od_config = OmniDiffusionConfig.from_kwargs(model=model_path, dtype=dtype) - od_config.model_class_name = "BagelPipeline" + model_class_name = None + uses_tf_config = True + + @classmethod + def load_pipeline(cls, model_path: str, device: str, dtype: torch.dtype) -> Any: + if cls.model_class_name is None: + raise ValueError("Adapter doesn't have a set class name.") - pipeline = BagelPipeline(od_config=od_config) - loader = DiffusersPipelineLoader(LoadConfig()) - loader.load_weights(pipeline) - pipeline.to(device) - return pipeline + od_config = OmniDiffusionConfig.from_kwargs( + model_class_name=cls.model_class_name, + model=model_path, + dtype=dtype, + ) + + if cls.uses_tf_config: + # TODO (Alex): Refactor to handle tf_model_config in OmniDiffusionConfig + # instead of OmniDiffusion and remove the manual population here + tf_config_dict = get_hf_file_to_dict( + os.path.join("transformer", "config.json"), + od_config.model, + ) + od_config.tf_model_config = TransformerConfig.from_dict(tf_config_dict) + + loader = DiffusersPipelineLoader(LoadConfig(), od_config=od_config) + # load_model will handle dtypes / device placement, put in .eval() mode + return loader.load_model(od_config=od_config, load_device=device) @staticmethod def get_transformer(pipeline: Any) -> tuple[Any, str]: - return pipeline.bagel, "Bagel" + return pipeline.transformer, pipeline.transformer.__class__.__name__ @staticmethod def install_hook(transformer: Any, hook: DataCollectionHook) -> None: @@ -78,25 +93,17 @@ def install_hook(transformer: Any, hook: DataCollectionHook) -> None: registry.register_hook(hook._HOOK_NAME, hook) -class StableAudioAdapter: - """Adapter for Stable Audio Open 1.0 coefficient estimation.""" - - @staticmethod - def load_pipeline(model_path: str, device: str = "cuda", dtype: torch.dtype = torch.float16) -> Any: - od_config = OmniDiffusionConfig.from_kwargs(model=model_path, dtype=dtype) - - # Strictly necessary because we bypass loader.load_model() - with set_default_torch_dtype(dtype): - pipeline = StableAudioPipeline(od_config=od_config) +class BagelAdapter(DefaultAdapter): + """Adapter for Bagel model.""" - loader = DiffusersPipelineLoader(LoadConfig()) - loader.load_weights(pipeline) - pipeline.to(device) - return pipeline + model_class_name = "BagelPipeline" + # Skip the hack for loading the tf model config, + # because bagel doesn't use it. + uses_tf_config = False @staticmethod def get_transformer(pipeline: Any) -> tuple[Any, str]: - return pipeline.transformer, "StableAudioDiTModel" + return pipeline.bagel, "Bagel" @staticmethod def install_hook(transformer: Any, hook: DataCollectionHook) -> None: @@ -104,52 +111,32 @@ def install_hook(transformer: Any, hook: DataCollectionHook) -> None: registry.register_hook(hook._HOOK_NAME, hook) -class Flux2Adapter: +class Flux2Adapter(DefaultAdapter): """Adapter for Flux2 model coefficient estimation.""" - @staticmethod - def load_pipeline(model_path: str, device: str = "cuda", dtype: torch.dtype = torch.bfloat16) -> Flux2Pipeline: - """Load Flux2 pipeline for coefficient estimation.""" - od_config = OmniDiffusionConfig.from_kwargs(model=model_path, dtype=dtype) - od_config.model_class_name = "Flux2Pipeline" - - pipeline = Flux2Pipeline(od_config=od_config) - loader = DiffusersPipelineLoader(LoadConfig()) - loader.load_weights(pipeline) - pipeline.to(device) - return pipeline + model_class_name = "Flux2Pipeline" - @staticmethod - def get_transformer(pipeline: Any) -> tuple[Any, str]: - return pipeline.transformer, pipeline.transformer.__class__.__name__ - @staticmethod - def install_hook(transformer: Any, hook: DataCollectionHook) -> None: - registry = HookRegistry.get_or_create(transformer) - registry.register_hook(hook._HOOK_NAME, hook) +class LongCatAdapter(DefaultAdapter): + """Adapter for LongCat Image - NOTE: currently this model needs the vLLM + context to be correctly configured to actually run the estimation, since it + uses vLLM norm layers etc. + """ + model_class_name = "LongCatImagePipeline" -class DefaultAdapter: - """Default adapter for standard diffusers pipelines.""" - @staticmethod - def load_pipeline(model_path: str, device: str, dtype: torch.dtype) -> Any: - raise NotImplementedError("DefaultAdapter.load_pipeline not implemented") - - @staticmethod - def get_transformer(pipeline: Any) -> tuple[Any, str]: - return pipeline.transformer, pipeline.transformer.__class__.__name__ +class StableAudioAdapter(DefaultAdapter): + """Adapter for Stable Audio Open 1.0 coefficient estimation.""" - @staticmethod - def install_hook(transformer: Any, hook: DataCollectionHook) -> None: - registry = HookRegistry.get_or_create(transformer) - registry.register_hook(hook._HOOK_NAME, hook) + model_class_name = "StableAudioPipeline" _MODEL_ADAPTERS: dict[str, type] = { "Bagel": BagelAdapter, "StableAudio": StableAudioAdapter, "Flux2": Flux2Adapter, + "LongCat": LongCatAdapter, } _EPSILON = 1e-6 @@ -196,7 +183,6 @@ def __init__( device: str = "cuda", dtype: torch.dtype = torch.bfloat16, ): - # Add validation here ⬇️ if model_type not in _MODEL_ADAPTERS: available_types = list(_MODEL_ADAPTERS.keys()) raise ValueError( @@ -205,7 +191,7 @@ def __init__( f"To add support for a new model, add an entry to _MODEL_ADAPTERS." ) - adapter = _MODEL_ADAPTERS.get(model_type, DefaultAdapter) + adapter = _MODEL_ADAPTERS[model_type] self.pipeline = adapter.load_pipeline(model_path, device, dtype) self.transformer, self.transformer_type = adapter.get_transformer(self.pipeline) self.hook = DataCollectionHook(self.transformer_type) diff --git a/vllm_omni/diffusion/cache/teacache/config.py b/vllm_omni/diffusion/cache/teacache/config.py index ecf3bfc1d3d..7efdd418e12 100644 --- a/vllm_omni/diffusion/cache/teacache/config.py +++ b/vllm_omni/diffusion/cache/teacache/config.py @@ -73,6 +73,8 @@ 3.20000000e00, -2.00000000e-02, ], + # LongCat Image transformer coefficients + "LongCatImageTransformer2DModel": [652.5980, -424.1615, 84.5526, -4.5923, 0.1694], } diff --git a/vllm_omni/diffusion/cache/teacache/extractors.py b/vllm_omni/diffusion/cache/teacache/extractors.py index 84c237b60d5..d0da0d9df3f 100644 --- a/vllm_omni/diffusion/cache/teacache/extractors.py +++ b/vllm_omni/diffusion/cache/teacache/extractors.py @@ -19,10 +19,13 @@ import torch import torch.nn as nn +from vllm.logger import init_logger from vllm_omni.diffusion.forward_context import get_forward_context from vllm_omni.platforms import current_omni_platform +logger = init_logger(__name__) + @dataclass class CacheContext: @@ -723,6 +726,105 @@ def postprocess(h): ) +def extract_longcat_context( + module: nn.Module, # LongCatImageTransformer2DModel + hidden_states, + timestep, + guidance, + encoder_hidden_states, + txt_ids, + img_ids, + **kwargs, +) -> CacheContext: + """Extract the cache context for LongCat Image. + + Similar to other extractors, this is currently the only code needed + for TeaCache support for LongCat image, and encapsulates preprocessing, + modulated input extraction, transformer execution, and postprocessing + logic. + + Args & kawrgs are identical to the inputs to LongCat Image's forward. + + Returns: + CacheContext with all information needed for generic caching + """ + # TODO (Alex) - Refactor TeaCache extractors to more tightly integrate with .forward + from diffusers.models.modeling_outputs import Transformer2DModelOutput + + # 1. Model specific preprocessing + fwd_context = get_forward_context() + sp_size = module.parallel_config.sequence_parallel_size + if sp_size is not None and sp_size > 1: + # NOTE: For now, we set this to False on the forward context + # to be consistent with LongCat Image's current behavior when + # TeaCache is enabled. We do not need to reset it in post process + # since we should never split text embed in sp for this model. + fwd_context.split_text_embed_in_sp = False + + hidden_states = module.x_embedder(hidden_states) + + timestep = timestep.to(hidden_states.dtype) * 1000 + + temb = module.time_embed(timestep, hidden_states.dtype) + encoder_hidden_states = module.context_embedder(encoder_hidden_states) + + # Compute RoPE embeddings via rope_preparer module + # _sp_plan will automatically shard img_cos/img_sin (outputs 2, 3) + # txt_cos/txt_sin (outputs 0, 1) remain replicated for dual-stream attention + txt_cos, txt_sin, img_cos, img_sin = module.rope_preparer(txt_ids, img_ids) + + # Reconstruct image_rotary_emb with chunked values + # Final shape: (txt_seq_len + img_seq_len // SP, head_dim) + image_rotary_emb = ( + torch.cat([txt_cos, img_cos], dim=0), + torch.cat([txt_sin, img_sin], dim=0), + ) + + # 2. Extract the modulated output from the first mm-DiT block + first_block = module.transformer_blocks[0] + img_modulated = first_block.norm1(hidden_states, emb=temb)[0] + + # 3. Define the transformer execution + def run_transformer_blocks(): + """Execute all Longcat transformer blocks.""" + h = hidden_states + e = encoder_hidden_states + for block in module.transformer_blocks: + e, h = block( + hidden_states=h, + encoder_hidden_states=e, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) + + for block in module.single_transformer_blocks: + e, h = block( + hidden_states=h, + encoder_hidden_states=e, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) + # Hook expects hidden states to be first + return (h, e) + + # 4. Postprocessing + def postprocess(h): + """Apply Longcat-specific output postprocessing.""" + h = module.norm_out(h, temb) + output = module.proj_out(h) + return Transformer2DModelOutput(sample=output) + + # 5. Return the CacheContext + return CacheContext( + modulated_input=img_modulated, + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + run_transformer_blocks=run_transformer_blocks, + postprocess=postprocess, + ) + + def extract_stable_audio_context( module: nn.Module, hidden_states: torch.Tensor, @@ -980,6 +1082,7 @@ def postprocess(h): "Flux2Klein": extract_flux2_klein_context, "StableAudioDiTModel": extract_stable_audio_context, "Flux2Transformer2DModel": extract_flux2_context, + "LongCatImageTransformer2DModel": extract_longcat_context, # Future models: # "FluxTransformer2DModel": extract_flux_context, # "CogVideoXTransformer3DModel": extract_cogvideox_context,