Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions vllm_omni/diffusion/cache/teacache/extractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,21 @@ def extract_qwen_context(
# ============================================================================
# PREPROCESSING (Qwen-specific)
# ============================================================================
hidden_states = module.img_in(hidden_states)
# Call image_rope_prepare instead of img_in + pos_embed directly.
# This ensures the SequenceParallelSplitHook registered on image_rope_prepare
# fires when SP is enabled, correctly sharding hidden_states and vid_freqs.
hidden_states, vid_freqs, txt_freqs = module.image_rope_prepare(hidden_states, img_shapes, txt_seq_lens)
image_rotary_emb = (vid_freqs, txt_freqs)

timestep = timestep.to(device=hidden_states.device, dtype=hidden_states.dtype)

# Call modulate_index_prepare instead of handling timestep directly.
# For zero_cond_t=False: timestep unchanged, modulate_index=None.
# For zero_cond_t=True: timestep is doubled, modulate_index is created and
# sharded by the SequenceParallelSplitHook on modulate_index_prepare so that
# its sequence dimension matches the already-sharded hidden_states.
timestep, modulate_index = module.modulate_index_prepare(timestep, img_shapes)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Keep TeaCache extractor aligned with zero_cond_t batching

For Qwen checkpoints with zero_cond_t=True (the image-edit variants), ModulateIndexPrepare.forward() doubles timestep here (qwen_image_transformer.py:141-154), so temb becomes 2 * batch. The rest of extract_qwen_context() still consumes that embedding as if it were batch: block.img_norm1(hidden_states, img_mod1) is still called without modulate_index, and postprocess() never chunks temb back down before module.norm_out, unlike QwenImageTransformer2DModel.forward() (qwen_image_transformer.py:1062-1065). With TeaCache enabled, Qwen edit models will therefore fail on the first forward with a batch-dimension mismatch instead of running the transformer.

Useful? React with 👍 / 👎.


encoder_hidden_states = module.txt_norm(encoder_hidden_states)
encoder_hidden_states = module.txt_in(encoder_hidden_states)

Expand All @@ -202,8 +215,6 @@ def extract_qwen_context(
else module.time_text_embed(timestep, guidance, hidden_states, additional_t_cond)
)

image_rotary_emb = module.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)

# ============================================================================
# EXTRACT MODULATED INPUT (for cache decision)
# ============================================================================
Expand Down Expand Up @@ -249,6 +260,7 @@ def run_transformer_blocks():
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=attention_kwargs,
modulate_index=modulate_index,
hidden_states_mask=hidden_states_mask,
)
return (h, e)
Expand Down
Loading