-
Notifications
You must be signed in to change notification settings - Fork 959
Add Teacache Support for LongCat Image #1487
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
afe43c0
422ec7e
f13bc11
f6fb7e9
f153aad
21edcd4
82f3bb6
5c4a3cd
acba1dd
eac1d45
c55112c
7a8e8fb
4815cbc
d7a19a3
1a8d1ba
567b336
eebbf2f
537a0f5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,10 +19,13 @@ | |
|
|
||
| import torch | ||
| import torch.nn as nn | ||
| from vllm.logger import init_logger | ||
|
|
||
| from vllm_omni.diffusion.forward_context import get_forward_context | ||
| from vllm_omni.platforms import current_omni_platform | ||
|
|
||
| logger = init_logger(__name__) | ||
|
|
||
|
|
||
| @dataclass | ||
| class CacheContext: | ||
|
|
@@ -723,6 +726,105 @@ def postprocess(h): | |
| ) | ||
|
|
||
|
|
||
| def extract_longcat_context( | ||
| module: nn.Module, # LongCatImageTransformer2DModel | ||
| hidden_states, | ||
| timestep, | ||
| guidance, | ||
| encoder_hidden_states, | ||
| txt_ids, | ||
| img_ids, | ||
| **kwargs, | ||
| ) -> CacheContext: | ||
| """Extract the cache context for LongCat Image. | ||
|
|
||
| Similar to other extractors, this is currently the only code needed | ||
| for TeaCache support for LongCat image, and encapsulates preprocessing, | ||
| modulated input extraction, transformer execution, and postprocessing | ||
| logic. | ||
|
|
||
| Args & kwargs are identical to the inputs to LongCat Image's forward. | ||
|
|
||
| Returns: | ||
| CacheContext with all information needed for generic caching | ||
| """ | ||
| # TODO (Alex) - Refactor TeaCache extractors to more tightly integrate with .forward | ||
| from diffusers.models.modeling_outputs import Transformer2DModelOutput | ||
|
alex-jw-brooks marked this conversation as resolved.
|
||
|
|
||
| # 1. Model specific preprocessing | ||
|
alex-jw-brooks marked this conversation as resolved.
|
||
| fwd_context = get_forward_context() | ||
| sp_size = module.parallel_config.sequence_parallel_size | ||
| if sp_size is not None and sp_size > 1: | ||
| # NOTE: For now, we set this to False on the forward context | ||
| # to be consistent with LongCat Image's current behavior when | ||
| # TeaCache is enabled. We do not need to reset it in post process | ||
| # since we should never split text embed in sp for this model. | ||
| fwd_context.split_text_embed_in_sp = False | ||
|
|
||
| hidden_states = module.x_embedder(hidden_states) | ||
|
Comment on lines
+756
to
+764
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
In the SP case ( Useful? React with 👍 / 👎.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Useful, but I think there are larger underlying problems in SP for this model at the moment (see #1556). I will investigate the fix for that as well, but see the same error with & without TeaCache at the moment, so open to any direction for how to handle it on this PR |
||
|
|
||
| timestep = timestep.to(hidden_states.dtype) * 1000 | ||
|
|
||
| temb = module.time_embed(timestep, hidden_states.dtype) | ||
| encoder_hidden_states = module.context_embedder(encoder_hidden_states) | ||
|
|
||
| # Compute RoPE embeddings via rope_preparer module | ||
| # _sp_plan will automatically shard img_cos/img_sin (outputs 2, 3) | ||
| # txt_cos/txt_sin (outputs 0, 1) remain replicated for dual-stream attention | ||
| txt_cos, txt_sin, img_cos, img_sin = module.rope_preparer(txt_ids, img_ids) | ||
|
|
||
| # Reconstruct image_rotary_emb with chunked values | ||
| # Final shape: (txt_seq_len + img_seq_len // SP, head_dim) | ||
| image_rotary_emb = ( | ||
| torch.cat([txt_cos, img_cos], dim=0), | ||
| torch.cat([txt_sin, img_sin], dim=0), | ||
| ) | ||
|
|
||
| # 2. Extract the modulated output from the first mm-DiT block | ||
| first_block = module.transformer_blocks[0] | ||
| img_modulated = first_block.norm1(hidden_states, emb=temb)[0] | ||
|
|
||
| # 3. Define the transformer execution | ||
| def run_transformer_blocks(): | ||
| """Execute all Longcat transformer blocks.""" | ||
| h = hidden_states | ||
| e = encoder_hidden_states | ||
| for block in module.transformer_blocks: | ||
| e, h = block( | ||
|
alex-jw-brooks marked this conversation as resolved.
|
||
| hidden_states=h, | ||
| encoder_hidden_states=e, | ||
| temb=temb, | ||
| image_rotary_emb=image_rotary_emb, | ||
| ) | ||
|
|
||
| for block in module.single_transformer_blocks: | ||
| e, h = block( | ||
| hidden_states=h, | ||
| encoder_hidden_states=e, | ||
| temb=temb, | ||
| image_rotary_emb=image_rotary_emb, | ||
| ) | ||
| # Hook expects hidden states to be first | ||
| return (h, e) | ||
|
|
||
| # 4. Postprocessing | ||
| def postprocess(h): | ||
| """Apply Longcat-specific output postprocessing.""" | ||
| h = module.norm_out(h, temb) | ||
| output = module.proj_out(h) | ||
| return Transformer2DModelOutput(sample=output) | ||
|
|
||
| # 5. Return the CacheContext | ||
| return CacheContext( | ||
| modulated_input=img_modulated, | ||
| hidden_states=hidden_states, | ||
| encoder_hidden_states=encoder_hidden_states, | ||
| temb=temb, | ||
| run_transformer_blocks=run_transformer_blocks, | ||
| postprocess=postprocess, | ||
| ) | ||
|
|
||
|
|
||
| def extract_stable_audio_context( | ||
| module: nn.Module, | ||
| hidden_states: torch.Tensor, | ||
|
|
@@ -980,6 +1082,7 @@ def postprocess(h): | |
| "Flux2Klein": extract_flux2_klein_context, | ||
| "StableAudioDiTModel": extract_stable_audio_context, | ||
| "Flux2Transformer2DModel": extract_flux2_context, | ||
| "LongCatImageTransformer2DModel": extract_longcat_context, | ||
| # Future models: | ||
| # "FluxTransformer2DModel": extract_flux_context, | ||
| # "CogVideoXTransformer3DModel": extract_cogvideox_context, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this also be wrapped with `set_default_torch_dtype(od_config.dtype)`, like `BagelAdapter.load_pipeline` was updated to do above?
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I had actually added a `set_default_torch_dtype` around the call to the load pipeline on the adapter instead of just putting it around the one line 🙂 The better way to do this is (code snippet omitted in this capture), because `load_model` will handle the device placement, put the model in eval mode, and handle the dtypes from the diffusion config. Updated both to avoid managing default dtypes manually and made sure the bagel one still runs.