From 16e724e94cbb16561c101cf0e6fcaf645b2c951c Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Mon, 23 Mar 2026 22:32:21 +0800 Subject: [PATCH 1/2] fix qwen-image Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- .../diffusion/cache/teacache/extractors.py | 27 +++++++++++++++---- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/vllm_omni/diffusion/cache/teacache/extractors.py b/vllm_omni/diffusion/cache/teacache/extractors.py index 78029791916..088f41ad8f2 100644 --- a/vllm_omni/diffusion/cache/teacache/extractors.py +++ b/vllm_omni/diffusion/cache/teacache/extractors.py @@ -188,8 +188,21 @@ def extract_qwen_context( # ============================================================================ # PREPROCESSING (Qwen-specific) # ============================================================================ - hidden_states = module.img_in(hidden_states) + # Call image_rope_prepare instead of img_in + pos_embed directly. + # This ensures the SequenceParallelSplitHook registered on image_rope_prepare + # fires when SP is enabled, correctly sharding hidden_states and vid_freqs. + hidden_states, vid_freqs, txt_freqs = module.image_rope_prepare(hidden_states, img_shapes, txt_seq_lens) + image_rotary_emb = (vid_freqs, txt_freqs) + timestep = timestep.to(device=hidden_states.device, dtype=hidden_states.dtype) + + # Call modulate_index_prepare instead of handling timestep directly. + # For zero_cond_t=False: timestep unchanged, modulate_index=None. + # For zero_cond_t=True: timestep is doubled, modulate_index is created and + # sharded by the SequenceParallelSplitHook on modulate_index_prepare so that + # its sequence dimension matches the already-sharded hidden_states. + timestep, modulate_index = module.modulate_index_prepare(timestep, img_shapes) + encoder_hidden_states = module.txt_norm(encoder_hidden_states) encoder_hidden_states = module.txt_in(encoder_hidden_states) @@ -202,13 +215,14 @@ def extract_qwen_context( else module.time_text_embed(timestep, guidance, hidden_states, additional_t_cond) ) - image_rotary_emb = module.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device) - # ============================================================================ # EXTRACT MODULATED INPUT (for cache decision) # ============================================================================ block = module.transformer_blocks[0] - img_mod_params = block.img_mod(temb) + # For zero_cond_t=True, temb has 2x batch size (doubled timestep). + # Use only the first half for the modulated input extraction. + temb_for_mod = temb[: temb.shape[0] // 2] if module.zero_cond_t else temb + img_mod_params = block.img_mod(temb_for_mod) img_mod1, _ = img_mod_params.chunk(2, dim=-1) img_modulated, _ = block.img_norm1(hidden_states, img_mod1) @@ -249,6 +263,7 @@ def run_transformer_blocks(): temb=temb, image_rotary_emb=image_rotary_emb, joint_attention_kwargs=attention_kwargs, + modulate_index=modulate_index, hidden_states_mask=hidden_states_mask, ) return (h, e) @@ -260,7 +275,9 @@ def run_transformer_blocks(): def postprocess(h): """Apply Qwen-specific output postprocessing.""" - h = module.norm_out(h, temb) + # For zero_cond_t=True, temb has 2x batch size; use only the first half. + pp_temb = temb.chunk(2, dim=0)[0] if module.zero_cond_t else temb + h = module.norm_out(h, pp_temb) output = module.proj_out(h) if not return_dict: return (output,) From 23bda666b73974b3ff5e46f9a095b8209024faaa Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Mon, 23 Mar 2026 22:34:50 +0800 Subject: [PATCH 2/2] revert Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- vllm_omni/diffusion/cache/teacache/extractors.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/vllm_omni/diffusion/cache/teacache/extractors.py b/vllm_omni/diffusion/cache/teacache/extractors.py index 088f41ad8f2..41ff4fa2719 100644 --- a/vllm_omni/diffusion/cache/teacache/extractors.py +++ b/vllm_omni/diffusion/cache/teacache/extractors.py @@ -219,10 +219,7 @@ def extract_qwen_context( # EXTRACT MODULATED INPUT (for cache decision) # ============================================================================ block = module.transformer_blocks[0] - # For zero_cond_t=True, temb has 2x batch size (doubled timestep). - # Use only the first half for the modulated input extraction. - temb_for_mod = temb[: temb.shape[0] // 2] if module.zero_cond_t else temb - img_mod_params = block.img_mod(temb_for_mod) + img_mod_params = block.img_mod(temb) img_mod1, _ = img_mod_params.chunk(2, dim=-1) img_modulated, _ = block.img_norm1(hidden_states, img_mod1) @@ -275,9 +272,7 @@ def run_transformer_blocks(): def postprocess(h): """Apply Qwen-specific output postprocessing.""" - # For zero_cond_t=True, temb has 2x batch size; use only the first half. - pp_temb = temb.chunk(2, dim=0)[0] if module.zero_cond_t else temb - h = module.norm_out(h, pp_temb) + h = module.norm_out(h, temb) output = module.proj_out(h) if not return_dict: return (output,)