From afe43c00b0ee9a4c703d31d47c710c5485b111c1 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Tue, 24 Feb 2026 04:21:03 +0000 Subject: [PATCH 01/18] first pass at longcat teacache Signed-off-by: Alex Brooks --- vllm_omni/diffusion/cache/teacache/config.py | 8 ++ .../diffusion/cache/teacache/extractors.py | 79 +++++++++++++++++++ 2 files changed, 87 insertions(+) diff --git a/vllm_omni/diffusion/cache/teacache/config.py b/vllm_omni/diffusion/cache/teacache/config.py index ecf3bfc1d3d..d20cff0903e 100644 --- a/vllm_omni/diffusion/cache/teacache/config.py +++ b/vllm_omni/diffusion/cache/teacache/config.py @@ -53,6 +53,14 @@ 68.05368574596551, -12.281286412689623, 1.0733905006198015, + # LongCat Image transformer coefficients + # Using Qwen-image to start + "LongCatImageTransformer2DModel": [ + -4.50000000e02, + 2.80000000e02, + -4.50000000e01, + 3.20000000e00, + -2.00000000e-02, ], # HunyuanImage3 pipeline coefficients # Calibrated via polyfit on 3920 data points (80 prompts × 49 steps). diff --git a/vllm_omni/diffusion/cache/teacache/extractors.py b/vllm_omni/diffusion/cache/teacache/extractors.py index 84c237b60d5..63af9890446 100644 --- a/vllm_omni/diffusion/cache/teacache/extractors.py +++ b/vllm_omni/diffusion/cache/teacache/extractors.py @@ -722,6 +722,84 @@ def postprocess(h): }, ) +def extract_longcat_context( + module: nn.Module, # LongCatImageTransformer2DModel + hidden_states, + timestep, + guidance, + encoder_hidden_states, + txt_ids, + img_ids, + **kwargs, +) -> CacheContext: + """""" + from diffusers.models.modeling_outputs import Transformer2DModelOutput + + # 1. 
Model specific preprocessing + # TODO - fix sequence parallelism + sp_size = module.parallel_config.sequence_parallel_size + get_forward_context().sequence_parallel_size = sp_size + + hidden_states = module.x_embedder(hidden_states) + + timestep = timestep.to(hidden_states.dtype) * 1000 + + temb = module.time_embed(timestep, hidden_states.dtype) + encoder_hidden_states = module.context_embedder(encoder_hidden_states) + + ids = torch.cat((txt_ids, img_ids), dim=0) + image_rotary_emb = module.pos_embed(ids) + + # 2. Extract the modulated output from the first mm-DiT block + first_block = module.transformer_blocks[0] + hs, _ = first_block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) + img_modulated = first_block.norm1(hs, emb=temb)[0] + + # 3. Define the transformer execution + def run_transformer_blocks(): + """Execute all Longcat transformer blocks.""" + h = hidden_states + e = encoder_hidden_states + for block in module.transformer_blocks: + e, h = block( + hidden_states=h, + encoder_hidden_states=e, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) + + for block in module.single_transformer_blocks: + e, h = block( + hidden_states=h, + encoder_hidden_states=e, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) + + return (h, e) + + # 4. Postprocessing + def postprocess(h): + """Apply Longcat-specific output postprocessing.""" + h = module.norm_out(h, temb) + output = module.proj_out(h) + return Transformer2DModelOutput(sample=output) + + # 5. 
Return the CacheContext + return CacheContext( + modulated_input=img_modulated, + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + run_transformer_blocks=run_transformer_blocks, + postprocess=postprocess, + ) + def extract_stable_audio_context( module: nn.Module, @@ -980,6 +1058,7 @@ def postprocess(h): "Flux2Klein": extract_flux2_klein_context, "StableAudioDiTModel": extract_stable_audio_context, "Flux2Transformer2DModel": extract_flux2_context, + "LongCatImageTransformer2DModel": extract_longcat_context, # Future models: # "FluxTransformer2DModel": extract_flux_context, # "CogVideoXTransformer3DModel": extract_cogvideox_context, From 422ec7ec1664caea0c46d5f9af7c171cf47a254f Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Fri, 27 Feb 2026 17:10:07 +0000 Subject: [PATCH 02/18] update longcat image teacache coefficients Signed-off-by: Alex Brooks --- vllm_omni/diffusion/cache/teacache/config.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm_omni/diffusion/cache/teacache/config.py b/vllm_omni/diffusion/cache/teacache/config.py index d20cff0903e..3a833f36437 100644 --- a/vllm_omni/diffusion/cache/teacache/config.py +++ b/vllm_omni/diffusion/cache/teacache/config.py @@ -56,11 +56,11 @@ # LongCat Image transformer coefficients # Using Qwen-image to start "LongCatImageTransformer2DModel": [ - -4.50000000e02, - 2.80000000e02, - -4.50000000e01, - 3.20000000e00, - -2.00000000e-02, + -90.25, + 95.12, + 31.15, + 4.85, + -0.017, ], # HunyuanImage3 pipeline coefficients # Calibrated via polyfit on 3920 data points (80 prompts × 49 steps). 
From f13bc1131b52ea9806bdaf74ec926c43aa008adb Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Fri, 27 Feb 2026 17:13:50 +0000 Subject: [PATCH 03/18] add teacache adapter for longcat Signed-off-by: Alex Brooks --- .../cache/teacache/coefficient_estimator.py | 29 ++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/vllm_omni/diffusion/cache/teacache/coefficient_estimator.py b/vllm_omni/diffusion/cache/teacache/coefficient_estimator.py index baec21c2762..e9ae01cc9c0 100644 --- a/vllm_omni/diffusion/cache/teacache/coefficient_estimator.py +++ b/vllm_omni/diffusion/cache/teacache/coefficient_estimator.py @@ -14,6 +14,7 @@ from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.models.bagel.pipeline_bagel import BagelPipeline from vllm_omni.diffusion.models.flux2.pipeline_flux2 import Flux2Pipeline +from vllm_omni.diffusion.models.longcat_image.pipeline_longcat_image import LongCatImagePipeline from vllm_omni.diffusion.models.stable_audio.pipeline_stable_audio import StableAudioPipeline from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.inputs.data import OmniDiffusionSamplingParams @@ -36,6 +37,7 @@ def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module: def new_forward(self, module: torch.nn.Module, *args: Any, **kwargs: Any) -> Any: ctx = self.extractor_fn(module, *args, **kwargs) + # NOTE: We upcast to float32 to also handle bfloat16. 
modulated_input_cpu = ctx.modulated_input.detach().float().cpu().numpy() outputs = ctx.run_transformer_blocks() @@ -64,7 +66,8 @@ def load_pipeline(model_path: str, device: str = "cuda", dtype: torch.dtype = to pipeline = BagelPipeline(od_config=od_config) loader = DiffusersPipelineLoader(LoadConfig()) - loader.load_weights(pipeline) + with set_default_torch_dtype(od_config.dtype): + loader.load_weights(pipeline) pipeline.to(device) return pipeline @@ -146,10 +149,34 @@ def install_hook(transformer: Any, hook: DataCollectionHook) -> None: registry.register_hook(hook._HOOK_NAME, hook) +class LongCatAdapter(DefaultAdapter): + """Adapter for LongCat Image - NOTE: currently this models needs the vLLM + context to be correctly configured to actually run the estimation, since it + uses vLLM norm layers etc. + """ + + @staticmethod + def load_pipeline(model_path: str, device: str, dtype: torch.dtype) -> Any: + + od_config = OmniDiffusionConfig.from_kwargs(model=model_path, dtype=dtype) + od_config.model_class_name = "LongCatImagePipeline" + + # TODO - we should use load_model in all adapters since the + # dtype is already specified in the Diffusion config anyway + with set_default_torch_dtype(dtype): + pipeline = LongCatImagePipeline(od_config=od_config) + + loader = DiffusersPipelineLoader(LoadConfig()) + loader.load_weights(pipeline) + pipeline.to(device) + return pipeline + + _MODEL_ADAPTERS: dict[str, type] = { "Bagel": BagelAdapter, "StableAudio": StableAudioAdapter, "Flux2": Flux2Adapter, + "LongCat": LongCatAdapter, } _EPSILON = 1e-6 From f6fb7e902cb4b5761166992d5ac8d769283eae48 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Fri, 27 Feb 2026 17:17:34 +0000 Subject: [PATCH 04/18] add workaround for vllm context to docs Signed-off-by: Alex Brooks --- docs/design/feature/teacache.md | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/docs/design/feature/teacache.md b/docs/design/feature/teacache.md index 9fa315cee77..b939366741f 
100644 --- a/docs/design/feature/teacache.md +++ b/docs/design/feature/teacache.md @@ -326,9 +326,30 @@ for prompt in tqdm(prompts, desc="Collecting data"): # Estimate coefficients coeffs = estimator.estimate(poly_order=4) -print(f"Estimated coefficients: {coeffs.tolist()}") +print(f"Estimated coefficients: {coeffs}") ``` +Note: some models may require the vLLM context and config to be initialized to initialize vLLM modules. To this end, you may need a workaround like the following to be able to run coefficient estimation. +```python +from vllm_omni.diffusion.forward_context import set_forward_context +... + +if __name__ == "__main__": + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "8192" + os.environ["LOCAL_RANK"] = "0" + os.environ["RANK"] = "0" + os.environ["WORLD_SIZE"] = "1" + + vllm_config = VllmConfig() + init_distributed_environment() + initialize_model_parallel() + + with set_forward_context(vllm_config): + +``` + + **Data Statistics Guide:** | Metric | Good Range | Warning Signs | From f153aad104b7df70cb5e0cdaa5f0740c56336ae8 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Fri, 27 Feb 2026 17:47:28 +0000 Subject: [PATCH 05/18] copy sequence parallel support to teacache Signed-off-by: Alex Brooks --- .../diffusion/cache/teacache/extractors.py | 70 ++++++++++++++++++- 1 file changed, 68 insertions(+), 2 deletions(-) diff --git a/vllm_omni/diffusion/cache/teacache/extractors.py b/vllm_omni/diffusion/cache/teacache/extractors.py index 63af9890446..26e10980e3c 100644 --- a/vllm_omni/diffusion/cache/teacache/extractors.py +++ b/vllm_omni/diffusion/cache/teacache/extractors.py @@ -19,10 +19,17 @@ import torch import torch.nn as nn +from vllm.logger import init_logger +from vllm_omni.diffusion.distributed.parallel_state import ( + get_sequence_parallel_rank, + get_sequence_parallel_world_size, +) from vllm_omni.diffusion.forward_context import get_forward_context from vllm_omni.platforms import current_omni_platform +logger = 
init_logger(__name__) + @dataclass class CacheContext: @@ -734,12 +741,33 @@ def extract_longcat_context( ) -> CacheContext: """""" from diffusers.models.modeling_outputs import Transformer2DModelOutput + # TODO (Alex) - Refactor TeaCache extractors to more tightly integrate with .forward # 1. Model specific preprocessing - # TODO - fix sequence parallelism sp_size = module.parallel_config.sequence_parallel_size + # Store SP size in forward context for sub-modules to access get_forward_context().sequence_parallel_size = sp_size + if sp_size > 1: + sp_world_size = get_sequence_parallel_world_size() + sp_rank = get_sequence_parallel_rank() + original_shape = hidden_states.shape + hidden_states = torch.chunk(hidden_states, sp_world_size, dim=1)[sp_rank] + # LongCat uses dual-stream (text + image) with joint attention + # Text embeddings should be replicated across SP ranks for correctness + get_forward_context().split_text_embed_in_sp = False + # Debug log (only first forward) + if not hasattr(module, "_sp_forward_logged"): + module._sp_forward_logged = True + logger.info( + f"[LongCat Transformer] SP enabled: sp_size={sp_size}, world_size={sp_world_size}, " + f"rank={sp_rank}, original_shape={original_shape}, chunked_shape={hidden_states.shape}" + ) + else: + if not hasattr(module, "_sp_forward_logged"): + module._sp_forward_logged = True + logger.info(f"[LongCat Transformer] SP disabled: sp_size={sp_size}") + hidden_states = module.x_embedder(hidden_states) timestep = timestep.to(hidden_states.dtype) * 1000 @@ -748,7 +776,45 @@ def extract_longcat_context( encoder_hidden_states = module.context_embedder(encoder_hidden_states) ids = torch.cat((txt_ids, img_ids), dim=0) - image_rotary_emb = module.pos_embed(ids) + + if current_omni_platform.is_npu(): + freqs_cos, freqs_sin = module.pos_embed(ids.cpu()) + image_rotary_emb = (freqs_cos.npu(), freqs_sin.npu()) + else: + image_rotary_emb = module.pos_embed(ids) + + # SP: Chunk RoPE embeddings along sequence dimension + if 
module.parallel_config.sequence_parallel_size > 1: + sp_world_size = get_sequence_parallel_world_size() + sp_rank = get_sequence_parallel_rank() + freqs_cos, freqs_sin = image_rotary_emb + txt_len = txt_ids.shape[0] + + # Split RoPE into text and image portions + # txt_freqs: (txt_seq_len, head_dim) - keep full for all ranks + # img_freqs: (img_seq_len, head_dim) -> (img_seq_len // SP, head_dim) + txt_freqs_cos = freqs_cos[:txt_len] + txt_freqs_sin = freqs_sin[:txt_len] + img_freqs_cos = freqs_cos[txt_len:] + img_freqs_sin = freqs_sin[txt_len:] + + # Chunk image RoPE for each SP rank + # img_freqs_cos: (img_seq_len // SP, head_dim) + # img_freqs_sin: (img_seq_len // SP, head_dim) + img_freqs_cos = torch.chunk(img_freqs_cos, sp_world_size, dim=0)[sp_rank] + img_freqs_sin = torch.chunk(img_freqs_sin, sp_world_size, dim=0)[sp_rank] + + # Optionally chunk text RoPE if split_text_embed_in_sp is True + if get_forward_context().split_text_embed_in_sp: + txt_freqs_cos = torch.chunk(txt_freqs_cos, sp_world_size, dim=0)[sp_rank] + txt_freqs_sin = torch.chunk(txt_freqs_sin, sp_world_size, dim=0)[sp_rank] + + # Reconstruct image_rotary_emb with chunked values + # Final shape: (txt_seq_len + img_seq_len // SP, head_dim) + image_rotary_emb = ( + torch.cat([txt_freqs_cos, img_freqs_cos], dim=0), + torch.cat([txt_freqs_sin, img_freqs_sin], dim=0), + ) # 2. 
Extract the modulated output from the first mm-DiT block first_block = module.transformer_blocks[0] From 21edcd45f19b2f8c8eff7fe6ecd5707c8d297b57 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Fri, 27 Feb 2026 18:20:47 +0000 Subject: [PATCH 06/18] fix block output Signed-off-by: Alex Brooks --- vllm_omni/diffusion/cache/teacache/extractors.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm_omni/diffusion/cache/teacache/extractors.py b/vllm_omni/diffusion/cache/teacache/extractors.py index 26e10980e3c..156220aa873 100644 --- a/vllm_omni/diffusion/cache/teacache/extractors.py +++ b/vllm_omni/diffusion/cache/teacache/extractors.py @@ -818,7 +818,7 @@ def extract_longcat_context( # 2. Extract the modulated output from the first mm-DiT block first_block = module.transformer_blocks[0] - hs, _ = first_block( + _, hs = first_block( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb, @@ -846,7 +846,7 @@ def run_transformer_blocks(): temb=temb, image_rotary_emb=image_rotary_emb, ) - + # Hook expects hidden states to be first return (h, e) # 4. 
Postprocessing From 82f3bb6131b669f7afd3f9733f957f58bc17739f Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Fri, 27 Feb 2026 18:40:42 +0000 Subject: [PATCH 07/18] correct coefficients Signed-off-by: Alex Brooks --- vllm_omni/diffusion/cache/teacache/config.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/vllm_omni/diffusion/cache/teacache/config.py b/vllm_omni/diffusion/cache/teacache/config.py index 3a833f36437..8b9dd5ffd0a 100644 --- a/vllm_omni/diffusion/cache/teacache/config.py +++ b/vllm_omni/diffusion/cache/teacache/config.py @@ -53,14 +53,6 @@ 68.05368574596551, -12.281286412689623, 1.0733905006198015, - # LongCat Image transformer coefficients - # Using Qwen-image to start - "LongCatImageTransformer2DModel": [ - -90.25, - 95.12, - 31.15, - 4.85, - -0.017, ], # HunyuanImage3 pipeline coefficients # Calibrated via polyfit on 3920 data points (80 prompts × 49 steps). @@ -81,6 +73,8 @@ 3.20000000e00, -2.00000000e-02, ], + # LongCat Image transformer coefficients + "LongCatImageTransformer2DModel": [116.500, -58.959, 6.909, 1.175, 0.108], } From 5c4a3cd9c6030431b6a85f14f7fc3a32f67d2c8b Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Sat, 14 Mar 2026 21:26:21 +0000 Subject: [PATCH 08/18] add docstring Signed-off-by: Alex Brooks --- vllm_omni/diffusion/cache/teacache/extractors.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/vllm_omni/diffusion/cache/teacache/extractors.py b/vllm_omni/diffusion/cache/teacache/extractors.py index 156220aa873..5c0f8ea99a5 100644 --- a/vllm_omni/diffusion/cache/teacache/extractors.py +++ b/vllm_omni/diffusion/cache/teacache/extractors.py @@ -739,9 +739,20 @@ def extract_longcat_context( img_ids, **kwargs, ) -> CacheContext: - """""" - from diffusers.models.modeling_outputs import Transformer2DModelOutput + """Extract the cache context for LongCat Image. 
+ + Similar to other extractors, this is currently the only code needed + for TeaCache support for LongCat image, and encapsulates preprocessing, + modulated input extraction, transformer execution, and postprocessing + logic. + + Args & kwargs are identical to the inputs to LongCat Image's forward. + + Returns: + CacheContext with all information needed for generic caching + """ # TODO (Alex) - Refactor TeaCache extractors to more tightly integrate with .forward + from diffusers.models.modeling_outputs import Transformer2DModelOutput # 1. Model specific preprocessing sp_size = module.parallel_config.sequence_parallel_size From acba1ddd14e48e39546b7eb81c5eb5a6d0366c79 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Sat, 14 Mar 2026 21:51:00 +0000 Subject: [PATCH 09/18] remove sequence parallel hacks in teacache extractor Signed-off-by: Alex Brooks --- .../diffusion/cache/teacache/extractors.py | 79 +++---------------- 1 file changed, 13 insertions(+), 66 deletions(-) diff --git a/vllm_omni/diffusion/cache/teacache/extractors.py b/vllm_omni/diffusion/cache/teacache/extractors.py index 5c0f8ea99a5..756f535900a 100644 --- a/vllm_omni/diffusion/cache/teacache/extractors.py +++ b/vllm_omni/diffusion/cache/teacache/extractors.py @@ -21,12 +21,7 @@ import torch.nn as nn from vllm.logger import init_logger -from vllm_omni.diffusion.distributed.parallel_state import ( - get_sequence_parallel_rank, - get_sequence_parallel_world_size, -) from vllm_omni.diffusion.forward_context import get_forward_context -from vllm_omni.platforms import current_omni_platform logger = init_logger(__name__) @@ -755,29 +750,10 @@ def extract_longcat_context( from diffusers.models.modeling_outputs import Transformer2DModelOutput # 1. 
Model specific preprocessing + fwd_context = get_forward_context() sp_size = module.parallel_config.sequence_parallel_size - # Store SP size in forward context for sub-modules to access - get_forward_context().sequence_parallel_size = sp_size - - if sp_size > 1: - sp_world_size = get_sequence_parallel_world_size() - sp_rank = get_sequence_parallel_rank() - original_shape = hidden_states.shape - hidden_states = torch.chunk(hidden_states, sp_world_size, dim=1)[sp_rank] - # LongCat uses dual-stream (text + image) with joint attention - # Text embeddings should be replicated across SP ranks for correctness - get_forward_context().split_text_embed_in_sp = False - # Debug log (only first forward) - if not hasattr(module, "_sp_forward_logged"): - module._sp_forward_logged = True - logger.info( - f"[LongCat Transformer] SP enabled: sp_size={sp_size}, world_size={sp_world_size}, " - f"rank={sp_rank}, original_shape={original_shape}, chunked_shape={hidden_states.shape}" - ) - else: - if not hasattr(module, "_sp_forward_logged"): - module._sp_forward_logged = True - logger.info(f"[LongCat Transformer] SP disabled: sp_size={sp_size}") + if sp_size is not None and sp_size > 1: + fwd_context.split_text_embed_in_sp = False hidden_states = module.x_embedder(hidden_states) @@ -786,46 +762,17 @@ def extract_longcat_context( temb = module.time_embed(timestep, hidden_states.dtype) encoder_hidden_states = module.context_embedder(encoder_hidden_states) - ids = torch.cat((txt_ids, img_ids), dim=0) + # Compute RoPE embeddings via rope_preparer module + # _sp_plan will automatically shard img_cos/img_sin (outputs 2, 3) + # txt_cos/txt_sin (outputs 0, 1) remain replicated for dual-stream attention + txt_cos, txt_sin, img_cos, img_sin = module.rope_preparer(txt_ids, img_ids) - if current_omni_platform.is_npu(): - freqs_cos, freqs_sin = module.pos_embed(ids.cpu()) - image_rotary_emb = (freqs_cos.npu(), freqs_sin.npu()) - else: - image_rotary_emb = module.pos_embed(ids) - - # SP: Chunk RoPE 
embeddings along sequence dimension - if module.parallel_config.sequence_parallel_size > 1: - sp_world_size = get_sequence_parallel_world_size() - sp_rank = get_sequence_parallel_rank() - freqs_cos, freqs_sin = image_rotary_emb - txt_len = txt_ids.shape[0] - - # Split RoPE into text and image portions - # txt_freqs: (txt_seq_len, head_dim) - keep full for all ranks - # img_freqs: (img_seq_len, head_dim) -> (img_seq_len // SP, head_dim) - txt_freqs_cos = freqs_cos[:txt_len] - txt_freqs_sin = freqs_sin[:txt_len] - img_freqs_cos = freqs_cos[txt_len:] - img_freqs_sin = freqs_sin[txt_len:] - - # Chunk image RoPE for each SP rank - # img_freqs_cos: (img_seq_len // SP, head_dim) - # img_freqs_sin: (img_seq_len // SP, head_dim) - img_freqs_cos = torch.chunk(img_freqs_cos, sp_world_size, dim=0)[sp_rank] - img_freqs_sin = torch.chunk(img_freqs_sin, sp_world_size, dim=0)[sp_rank] - - # Optionally chunk text RoPE if split_text_embed_in_sp is True - if get_forward_context().split_text_embed_in_sp: - txt_freqs_cos = torch.chunk(txt_freqs_cos, sp_world_size, dim=0)[sp_rank] - txt_freqs_sin = torch.chunk(txt_freqs_sin, sp_world_size, dim=0)[sp_rank] - - # Reconstruct image_rotary_emb with chunked values - # Final shape: (txt_seq_len + img_seq_len // SP, head_dim) - image_rotary_emb = ( - torch.cat([txt_freqs_cos, img_freqs_cos], dim=0), - torch.cat([txt_freqs_sin, img_freqs_sin], dim=0), - ) + # Reconstruct image_rotary_emb with chunked values + # Final shape: (txt_seq_len + img_seq_len // SP, head_dim) + image_rotary_emb = ( + torch.cat([txt_cos, img_cos], dim=0), + torch.cat([txt_sin, img_sin], dim=0), + ) # 2. 
Extract the modulated output from the first mm-DiT block first_block = module.transformer_blocks[0] From eac1d4557f0680eb2c46ac762b2bfb4043435e33 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Sat, 14 Mar 2026 22:13:46 +0000 Subject: [PATCH 10/18] fix img modulation in longcat image teacache Signed-off-by: Alex Brooks --- vllm_omni/diffusion/cache/teacache/extractors.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/vllm_omni/diffusion/cache/teacache/extractors.py b/vllm_omni/diffusion/cache/teacache/extractors.py index 756f535900a..17312591410 100644 --- a/vllm_omni/diffusion/cache/teacache/extractors.py +++ b/vllm_omni/diffusion/cache/teacache/extractors.py @@ -776,13 +776,7 @@ def extract_longcat_context( # 2. Extract the modulated output from the first mm-DiT block first_block = module.transformer_blocks[0] - _, hs = first_block( - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - temb=temb, - image_rotary_emb=image_rotary_emb, - ) - img_modulated = first_block.norm1(hs, emb=temb)[0] + img_modulated = first_block.norm1(hidden_states, emb=temb)[0] # 3. 
Define the transformer execution def run_transformer_blocks(): From c55112cb12c7fef0e75dfef60bfc4e37446dc673 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Sat, 14 Mar 2026 23:54:53 +0000 Subject: [PATCH 11/18] clean up coefficient dtype handling Signed-off-by: Alex Brooks --- .../cache/teacache/coefficient_estimator.py | 39 +++++++++---------- vllm_omni/diffusion/cache/teacache/config.py | 2 +- 2 files changed, 20 insertions(+), 21 deletions(-) diff --git a/vllm_omni/diffusion/cache/teacache/coefficient_estimator.py b/vllm_omni/diffusion/cache/teacache/coefficient_estimator.py index e9ae01cc9c0..fda9d0b8809 100644 --- a/vllm_omni/diffusion/cache/teacache/coefficient_estimator.py +++ b/vllm_omni/diffusion/cache/teacache/coefficient_estimator.py @@ -1,15 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import os from typing import Any import numpy as np import torch from vllm.config import LoadConfig -from vllm.utils.torch_utils import set_default_torch_dtype +from vllm.transformers_utils.config import get_hf_file_to_dict from vllm_omni.diffusion.cache.teacache.extractors import get_extractor -from vllm_omni.diffusion.data import OmniDiffusionConfig +from vllm_omni.diffusion.data import OmniDiffusionConfig, TransformerConfig from vllm_omni.diffusion.hooks import HookRegistry, ModelHook from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.models.bagel.pipeline_bagel import BagelPipeline @@ -60,16 +61,14 @@ class BagelAdapter: """Adapter for Bagel model.""" @staticmethod - def load_pipeline(model_path: str, device: str = "cuda", dtype: torch.dtype = torch.bfloat16) -> BagelPipeline: + def load_pipeline(model_path: str, device: str, dtype: torch.dtype) -> BagelPipeline: od_config = OmniDiffusionConfig.from_kwargs(model=model_path, dtype=dtype) od_config.model_class_name = "BagelPipeline" + # No hack needed for tf_model_config since Bagel 
doesn't use it - pipeline = BagelPipeline(od_config=od_config) - loader = DiffusersPipelineLoader(LoadConfig()) - with set_default_torch_dtype(od_config.dtype): - loader.load_weights(pipeline) - pipeline.to(device) - return pipeline + loader = DiffusersPipelineLoader(LoadConfig(), od_config=od_config) + # load_model will handle dtypes / device placement, put in .eval() mode + return loader.load_model(od_config=od_config, load_device=device) @staticmethod def get_transformer(pipeline: Any) -> tuple[Any, str]: @@ -156,20 +155,20 @@ class LongCatAdapter(DefaultAdapter): """ @staticmethod - def load_pipeline(model_path: str, device: str, dtype: torch.dtype) -> Any: - + def load_pipeline(model_path: str, device: str, dtype: torch.dtype) -> LongCatImagePipeline: od_config = OmniDiffusionConfig.from_kwargs(model=model_path, dtype=dtype) od_config.model_class_name = "LongCatImagePipeline" + # TODO (Alex): Refactor to handle tf_model_config in OmniDiffusionConfig + # instead of OmniDiffusion and remove the manual population here + tf_config_dict = get_hf_file_to_dict( + os.path.join("transformer", "config.json"), + od_config.model, + ) + od_config.tf_model_config = TransformerConfig.from_dict(tf_config_dict) - # TODO - we should use load_model in all adapters since the - # dtype is already specified in the Diffusion config anyway - with set_default_torch_dtype(dtype): - pipeline = LongCatImagePipeline(od_config=od_config) - - loader = DiffusersPipelineLoader(LoadConfig()) - loader.load_weights(pipeline) - pipeline.to(device) - return pipeline + loader = DiffusersPipelineLoader(LoadConfig(), od_config=od_config) + # load_model will handle dtypes / device placement, put in .eval() mode + return loader.load_model(od_config=od_config, load_device=device) _MODEL_ADAPTERS: dict[str, type] = { diff --git a/vllm_omni/diffusion/cache/teacache/config.py b/vllm_omni/diffusion/cache/teacache/config.py index 8b9dd5ffd0a..7efdd418e12 100644 --- 
a/vllm_omni/diffusion/cache/teacache/config.py +++ b/vllm_omni/diffusion/cache/teacache/config.py @@ -74,7 +74,7 @@ -2.00000000e-02, ], # LongCat Image transformer coefficients - "LongCatImageTransformer2DModel": [116.500, -58.959, 6.909, 1.175, 0.108], + "LongCatImageTransformer2DModel": [652.5980, -424.1615, 84.5526, -4.5923, 0.1694], } From 7a8e8fbb8bd1746d3b175114716383cdde00f9fa Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Sat, 14 Mar 2026 23:59:51 +0000 Subject: [PATCH 12/18] teacache doc tweaks Signed-off-by: Alex Brooks --- docs/design/feature/teacache.md | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/docs/design/feature/teacache.md b/docs/design/feature/teacache.md index b939366741f..8577cff1f05 100644 --- a/docs/design/feature/teacache.md +++ b/docs/design/feature/teacache.md @@ -332,6 +332,11 @@ print(f"Estimated coefficients: {coeffs}") Note: some models may require the vLLM context and config to be initialized to initialize vLLM modules. To this end, you may need a workaround like the following to be able to run coefficient estimation. ```python from vllm_omni.diffusion.forward_context import set_forward_context +from vllm_omni.diffusion.distributed.parallel_state import ( + init_distributed_environment, + initialize_model_parallel, +) +from vllm.config import VllmConfig ... if __name__ == "__main__": @@ -345,6 +350,12 @@ if __name__ == "__main__": init_distributed_environment() initialize_model_parallel() + # NOTE: you may have to pass an initialized OmniDiffusionConfig as a kwarg + # here to make current sp checks happy; if this is the case, just create one + # .from_kwargs() with the model name to get around this check for now, + # since your estimator subclass should handle the actual model configuration. 
+ # + # This will be cleaned up in the future with set_forward_context(vllm_config): ``` From 4815cbc7e7229a1a83dc951bd4de5333be662aa0 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Sun, 22 Mar 2026 04:00:41 +0000 Subject: [PATCH 13/18] add comment about split_text_embed_in_sp Signed-off-by: Alex Brooks --- vllm_omni/diffusion/cache/teacache/extractors.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm_omni/diffusion/cache/teacache/extractors.py b/vllm_omni/diffusion/cache/teacache/extractors.py index 17312591410..add1003bec5 100644 --- a/vllm_omni/diffusion/cache/teacache/extractors.py +++ b/vllm_omni/diffusion/cache/teacache/extractors.py @@ -753,6 +753,10 @@ def extract_longcat_context( fwd_context = get_forward_context() sp_size = module.parallel_config.sequence_parallel_size if sp_size is not None and sp_size > 1: + # NOTE: For now, we set this to False on the forward context + # to be consistent with LongCat Image's current behavior when + # TeaCache is enabled. We do not need to reset it in post process + # since we should never split text embed in sp for this model. 
fwd_context.split_text_embed_in_sp = False hidden_states = module.x_embedder(hidden_states) From d7a19a39014ed5d05eccb444fb1b1d9cde9a8a4c Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Mon, 23 Mar 2026 04:47:44 +0000 Subject: [PATCH 14/18] fmt Signed-off-by: Alex Brooks --- vllm_omni/diffusion/cache/teacache/extractors.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm_omni/diffusion/cache/teacache/extractors.py b/vllm_omni/diffusion/cache/teacache/extractors.py index add1003bec5..cd78d20238a 100644 --- a/vllm_omni/diffusion/cache/teacache/extractors.py +++ b/vllm_omni/diffusion/cache/teacache/extractors.py @@ -724,6 +724,7 @@ def postprocess(h): }, ) + def extract_longcat_context( module: nn.Module, # LongCatImageTransformer2DModel hidden_states, From 1a8d1bae73ea736fdbac8dce510834818b43bfdd Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Thu, 26 Mar 2026 23:18:51 +0000 Subject: [PATCH 15/18] make teacache estimator common Signed-off-by: Alex Brooks --- .../cache/teacache/coefficient_estimator.py | 80 +++++++++++++------ 1 file changed, 56 insertions(+), 24 deletions(-) diff --git a/vllm_omni/diffusion/cache/teacache/coefficient_estimator.py b/vllm_omni/diffusion/cache/teacache/coefficient_estimator.py index fda9d0b8809..4da77cc02d1 100644 --- a/vllm_omni/diffusion/cache/teacache/coefficient_estimator.py +++ b/vllm_omni/diffusion/cache/teacache/coefficient_estimator.py @@ -13,10 +13,13 @@ from vllm_omni.diffusion.data import OmniDiffusionConfig, TransformerConfig from vllm_omni.diffusion.hooks import HookRegistry, ModelHook from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +<<<<<<< HEAD from vllm_omni.diffusion.models.bagel.pipeline_bagel import BagelPipeline from vllm_omni.diffusion.models.flux2.pipeline_flux2 import Flux2Pipeline from vllm_omni.diffusion.models.longcat_image.pipeline_longcat_image import LongCatImagePipeline from vllm_omni.diffusion.models.stable_audio.pipeline_stable_audio import StableAudioPipeline 
+======= +>>>>>>> 79e29eaa (make teacache estimator common) from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.inputs.data import OmniDiffusionSamplingParams @@ -57,14 +60,31 @@ def stop_collection(self) -> list[tuple[np.ndarray, np.ndarray]]: return list(self.current_trajectory) -class BagelAdapter: - """Adapter for Bagel model.""" +class DefaultAdapter: + """Default adapter for standard diffusers pipelines.""" - @staticmethod - def load_pipeline(model_path: str, device: str, dtype: torch.dtype) -> BagelPipeline: - od_config = OmniDiffusionConfig.from_kwargs(model=model_path, dtype=dtype) - od_config.model_class_name = "BagelPipeline" - # No hack needed for tf_model_config since Bagel doesn't use it + model_class_name = None + uses_tf_config = True + + @classmethod + def load_pipeline(cls, model_path: str, device: str, dtype: torch.dtype) -> Any: + if cls.model_class_name is None: + raise ValueError("Adapter doesn't have a set class name.") + + od_config = OmniDiffusionConfig.from_kwargs( + model_class_name=cls.model_class_name, + model=model_path, + dtype=dtype, + ) + + if cls.uses_tf_config: + # TODO (Alex): Refactor to handle tf_model_config in OmniDiffusionConfig + # instead of OmniDiffusion and remove the manual population here + tf_config_dict = get_hf_file_to_dict( + os.path.join("transformer", "config.json"), + od_config.model, + ) + od_config.tf_model_config = TransformerConfig.from_dict(tf_config_dict) loader = DiffusersPipelineLoader(LoadConfig(), od_config=od_config) # load_model will handle dtypes / device placement, put in .eval() mode @@ -72,6 +92,7 @@ def load_pipeline(model_path: str, device: str, dtype: torch.dtype) -> BagelPipe @staticmethod def get_transformer(pipeline: Any) -> tuple[Any, str]: +<<<<<<< HEAD return pipeline.bagel, "Bagel" @staticmethod @@ -140,6 +161,8 @@ def load_pipeline(model_path: str, device: str, dtype: torch.dtype) -> Any: @staticmethod def get_transformer(pipeline: Any) -> tuple[Any, str]: 
+======= +>>>>>>> 79e29eaa (make teacache estimator common) return pipeline.transformer, pipeline.transformer.__class__.__name__ @staticmethod @@ -148,27 +171,37 @@ def install_hook(transformer: Any, hook: DataCollectionHook) -> None: registry.register_hook(hook._HOOK_NAME, hook) +class BagelAdapter(DefaultAdapter): + """Adapter for Bagel model.""" + + model_class_name = "BagelPipeline" + # Skip the hack for loading the tf model config, + # because bagel doesn't use it. + uses_tf_config = False + + @staticmethod + def get_transformer(pipeline: Any) -> tuple[Any, str]: + return pipeline.bagel, "Bagel" + + @staticmethod + def install_hook(transformer: Any, hook: DataCollectionHook) -> None: + registry = HookRegistry.get_or_create(transformer) + registry.register_hook(hook._HOOK_NAME, hook) + + class LongCatAdapter(DefaultAdapter): - """Adapter for LongCat Image - NOTE: currently this models needs the vLLM + """Adapter for LongCat Image - NOTE: currently this model needs the vLLM context to be correctly configured to actually run the estimation, since it uses vLLM norm layers etc. 
""" - @staticmethod - def load_pipeline(model_path: str, device: str, dtype: torch.dtype) -> LongCatImagePipeline: - od_config = OmniDiffusionConfig.from_kwargs(model=model_path, dtype=dtype) - od_config.model_class_name = "LongCatImagePipeline" - # TODO (Alex): Refactor to handle tf_model_config in OmniDiffusionConfig - # instead of OmniDiffusion and remove the manual population here - tf_config_dict = get_hf_file_to_dict( - os.path.join("transformer", "config.json"), - od_config.model, - ) - od_config.tf_model_config = TransformerConfig.from_dict(tf_config_dict) + model_class_name = "LongCatImagePipeline" - loader = DiffusersPipelineLoader(LoadConfig(), od_config=od_config) - # load_model will handle dtypes / device placement, put in .eval() mode - return loader.load_model(od_config=od_config, load_device=device) + +class StableAudioAdapter(DefaultAdapter): + """Adapter for Stable Audio Open 1.0 coefficient estimation.""" + + model_class_name = "StableAudioPipeline" _MODEL_ADAPTERS: dict[str, type] = { @@ -222,7 +255,6 @@ def __init__( device: str = "cuda", dtype: torch.dtype = torch.bfloat16, ): - # Add validation here ⬇️ if model_type not in _MODEL_ADAPTERS: available_types = list(_MODEL_ADAPTERS.keys()) raise ValueError( @@ -231,7 +263,7 @@ def __init__( f"To add support for a new model, add an entry to _MODEL_ADAPTERS." 
) - adapter = _MODEL_ADAPTERS.get(model_type, DefaultAdapter) + adapter = _MODEL_ADAPTERS[model_type] self.pipeline = adapter.load_pipeline(model_path, device, dtype) self.transformer, self.transformer_type = adapter.get_transformer(self.pipeline) self.hook = DataCollectionHook(self.transformer_type) From 567b33687f174dd98e02732e579da54da616335b Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Wed, 1 Apr 2026 09:24:42 +0000 Subject: [PATCH 16/18] update docs Signed-off-by: Alex Brooks --- docs/user_guide/diffusion_features.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/user_guide/diffusion_features.md b/docs/user_guide/diffusion_features.md index 4e7003cce37..d70fdd9df7e 100644 --- a/docs/user_guide/diffusion_features.md +++ b/docs/user_guide/diffusion_features.md @@ -115,8 +115,8 @@ The following tables show which models support each feature: | **FLUX.2-dev** | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | | **GLM-Image** | ❌ | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | | **HunyuanImage3** | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | -| **LongCat-Image** | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | -| **LongCat-Image-Edit** | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | +| **LongCat-Image** | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | +| **LongCat-Image-Edit** | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | | **MagiHuman** | ❌ | ❌ | ❌ | ❓ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | | **MammothModa2(T2I)** | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | **Nextstep_1(T2I)** | ❓ | ❓ | ❌ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | From eebbf2f2e5d60bcbdd7a9048cdc90623129cb9db Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Fri, 17 Apr 2026 16:46:57 +0000 Subject: [PATCH 17/18] rebase flux2 Signed-off-by: Alex Brooks --- .../cache/teacache/coefficient_estimator.py | 84 ++----------------- 1 file changed, 6 insertions(+), 78 deletions(-) diff --git a/vllm_omni/diffusion/cache/teacache/coefficient_estimator.py b/vllm_omni/diffusion/cache/teacache/coefficient_estimator.py index 4da77cc02d1..38c805c28db 
100644 --- a/vllm_omni/diffusion/cache/teacache/coefficient_estimator.py +++ b/vllm_omni/diffusion/cache/teacache/coefficient_estimator.py @@ -13,13 +13,6 @@ from vllm_omni.diffusion.data import OmniDiffusionConfig, TransformerConfig from vllm_omni.diffusion.hooks import HookRegistry, ModelHook from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader -<<<<<<< HEAD -from vllm_omni.diffusion.models.bagel.pipeline_bagel import BagelPipeline -from vllm_omni.diffusion.models.flux2.pipeline_flux2 import Flux2Pipeline -from vllm_omni.diffusion.models.longcat_image.pipeline_longcat_image import LongCatImagePipeline -from vllm_omni.diffusion.models.stable_audio.pipeline_stable_audio import StableAudioPipeline -======= ->>>>>>> 79e29eaa (make teacache estimator common) from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.inputs.data import OmniDiffusionSamplingParams @@ -92,77 +85,6 @@ def load_pipeline(cls, model_path: str, device: str, dtype: torch.dtype) -> Any: @staticmethod def get_transformer(pipeline: Any) -> tuple[Any, str]: -<<<<<<< HEAD - return pipeline.bagel, "Bagel" - - @staticmethod - def install_hook(transformer: Any, hook: DataCollectionHook) -> None: - registry = HookRegistry.get_or_create(transformer) - registry.register_hook(hook._HOOK_NAME, hook) - - -class StableAudioAdapter: - """Adapter for Stable Audio Open 1.0 coefficient estimation.""" - - @staticmethod - def load_pipeline(model_path: str, device: str = "cuda", dtype: torch.dtype = torch.float16) -> Any: - od_config = OmniDiffusionConfig.from_kwargs(model=model_path, dtype=dtype) - - # Strictly necessary because we bypass loader.load_model() - with set_default_torch_dtype(dtype): - pipeline = StableAudioPipeline(od_config=od_config) - - loader = DiffusersPipelineLoader(LoadConfig()) - loader.load_weights(pipeline) - pipeline.to(device) - return pipeline - - @staticmethod - def get_transformer(pipeline: Any) -> tuple[Any, str]: - return 
pipeline.transformer, "StableAudioDiTModel" - - @staticmethod - def install_hook(transformer: Any, hook: DataCollectionHook) -> None: - registry = HookRegistry.get_or_create(transformer) - registry.register_hook(hook._HOOK_NAME, hook) - - -class Flux2Adapter: - """Adapter for Flux2 model coefficient estimation.""" - - @staticmethod - def load_pipeline(model_path: str, device: str = "cuda", dtype: torch.dtype = torch.bfloat16) -> Flux2Pipeline: - """Load Flux2 pipeline for coefficient estimation.""" - od_config = OmniDiffusionConfig.from_kwargs(model=model_path, dtype=dtype) - od_config.model_class_name = "Flux2Pipeline" - - pipeline = Flux2Pipeline(od_config=od_config) - loader = DiffusersPipelineLoader(LoadConfig()) - loader.load_weights(pipeline) - pipeline.to(device) - return pipeline - - @staticmethod - def get_transformer(pipeline: Any) -> tuple[Any, str]: - return pipeline.transformer, pipeline.transformer.__class__.__name__ - - @staticmethod - def install_hook(transformer: Any, hook: DataCollectionHook) -> None: - registry = HookRegistry.get_or_create(transformer) - registry.register_hook(hook._HOOK_NAME, hook) - - -class DefaultAdapter: - """Default adapter for standard diffusers pipelines.""" - - @staticmethod - def load_pipeline(model_path: str, device: str, dtype: torch.dtype) -> Any: - raise NotImplementedError("DefaultAdapter.load_pipeline not implemented") - - @staticmethod - def get_transformer(pipeline: Any) -> tuple[Any, str]: -======= ->>>>>>> 79e29eaa (make teacache estimator common) return pipeline.transformer, pipeline.transformer.__class__.__name__ @staticmethod @@ -189,6 +111,12 @@ def install_hook(transformer: Any, hook: DataCollectionHook) -> None: registry.register_hook(hook._HOOK_NAME, hook) +class Flux2Adapter(DefaultAdapter): + """Adapter for Flux2 model coefficient estimation.""" + + model_class_name = "Flux2Pipeline" + + class LongCatAdapter(DefaultAdapter): """Adapter for LongCat Image - NOTE: currently this model needs the vLLM 
context to be correctly configured to actually run the estimation, since it From 537a0f55a01c7f0ddbcc112233ff067bde7cb32f Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Fri, 17 Apr 2026 17:09:10 +0000 Subject: [PATCH 18/18] fix flux2 import Signed-off-by: Alex Brooks --- vllm_omni/diffusion/cache/teacache/extractors.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm_omni/diffusion/cache/teacache/extractors.py b/vllm_omni/diffusion/cache/teacache/extractors.py index cd78d20238a..d0da0d9df3f 100644 --- a/vllm_omni/diffusion/cache/teacache/extractors.py +++ b/vllm_omni/diffusion/cache/teacache/extractors.py @@ -22,6 +22,7 @@ from vllm.logger import init_logger from vllm_omni.diffusion.forward_context import get_forward_context +from vllm_omni.platforms import current_omni_platform logger = init_logger(__name__)