diff --git a/vllm_omni/diffusion/cache/teacache/extractors.py b/vllm_omni/diffusion/cache/teacache/extractors.py index 7802979191..41ff4fa271 100644 --- a/vllm_omni/diffusion/cache/teacache/extractors.py +++ b/vllm_omni/diffusion/cache/teacache/extractors.py @@ -188,8 +188,21 @@ def extract_qwen_context( # ============================================================================ # PREPROCESSING (Qwen-specific) # ============================================================================ - hidden_states = module.img_in(hidden_states) + # Call image_rope_prepare instead of img_in + pos_embed directly. + # This ensures the SequenceParallelSplitHook registered on image_rope_prepare + # fires when SP is enabled, correctly sharding hidden_states and vid_freqs. + hidden_states, vid_freqs, txt_freqs = module.image_rope_prepare(hidden_states, img_shapes, txt_seq_lens) + image_rotary_emb = (vid_freqs, txt_freqs) + timestep = timestep.to(device=hidden_states.device, dtype=hidden_states.dtype) + + # Call modulate_index_prepare instead of handling timestep directly. + # For zero_cond_t=False: timestep unchanged, modulate_index=None. + # For zero_cond_t=True: timestep is doubled, modulate_index is created and + # sharded by the SequenceParallelSplitHook on modulate_index_prepare so that + # its sequence dimension matches the already-sharded hidden_states. + timestep, modulate_index = module.modulate_index_prepare(timestep, img_shapes) + encoder_hidden_states = module.txt_norm(encoder_hidden_states) encoder_hidden_states = module.txt_in(encoder_hidden_states) @@ -202,8 +215,6 @@ def extract_qwen_context( else module.time_text_embed(timestep, guidance, hidden_states, additional_t_cond) ) - image_rotary_emb = module.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device) - # ============================================================================ # EXTRACT MODULATED INPUT (for cache decision) # ============================================================================ @@ -249,6 +260,7 @@ def run_transformer_blocks(): temb=temb, image_rotary_emb=image_rotary_emb, joint_attention_kwargs=attention_kwargs, + modulate_index=modulate_index, hidden_states_mask=hidden_states_mask, ) return (h, e)