
[Bugfix] Fix Qwen-Image SP and TeaCache incompatibility#2101

Merged
gcanlin merged 2 commits into vllm-project:main from wtomin:fix-teacache
Mar 23, 2026

Conversation


@wtomin wtomin commented Mar 23, 2026


Purpose

Resolves #2092.

Root Cause:

When SP is enabled, the _sp_plan registers a SequenceParallelSplitHook on image_rope_prepare to shard hidden_states after its forward pass. In the extract_qwen_context function, however, the TeaCache extractor invokes module.img_in and module.pos_embed directly, bypassing image_rope_prepare. The hook therefore never fires, and hidden_states is not sharded.

Solution

Modify the extract_qwen_context function in vllm_omni/diffusion/cache/teacache/extractors.py:
  • Invoke image_rope_prepare: replace the direct calls to img_in and pos_embed so that the SP split hook is triggered
  • Invoke modulate_index_prepare: replace the direct timestep processing so that modulate_index is correctly sharded when zero_cond_t=True
  • Pass modulate_index: pass modulate_index to the transformer blocks in run_transformer_blocks

This ensures TeaCache is properly compatible with Sequence Parallelism (both Ulysses-SP and Ring Attention).
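The mechanism behind the bug can be illustrated with a minimal, torch-free sketch (the class and hook names below are illustrative stand-ins, not the real vllm_omni or SequenceParallelSplitHook APIs): a forward hook registered on a wrapper submodule only fires when the wrapper itself is called, so calling its inner modules directly leaves hidden_states unsharded.

```python
# Minimal sketch of the root cause, assuming a torch-like hook protocol.
# None of these names are the real vllm_omni classes.

class Module:
    def __init__(self):
        self._hooks = []

    def register_forward_hook(self, fn):
        self._hooks.append(fn)

    def __call__(self, *args):
        out = self.forward(*args)
        for hook in self._hooks:
            out = hook(self, args, out)  # hook may replace the output
        return out

class ImgIn(Module):
    def forward(self, hidden_states):
        return hidden_states  # stand-in for the patch-embedding projection

class ImageRopePrepare(Module):
    """Wrapper that the _sp_plan attaches the SP split hook to."""
    def __init__(self):
        super().__init__()
        self.img_in = ImgIn()

    def forward(self, hidden_states):
        return self.img_in(hidden_states)

def sp_split_hook(module, args, output, rank=0, world_size=2):
    # Shard the sequence dimension, as the SP split hook would.
    n = len(output) // world_size
    return output[rank * n:(rank + 1) * n]

rope_prepare = ImageRopePrepare()
rope_prepare.register_forward_hook(sp_split_hook)

tokens = list(range(8))
# Buggy extractor path: inner module called directly, hook bypassed.
unsharded = rope_prepare.img_in(tokens)
# Fixed extractor path: wrapper called, hook fires and shards.
sharded = rope_prepare(tokens)
print(len(unsharded), len(sharded))  # 8 4
```

The fix is therefore just a matter of routing the extractor through the same wrapper submodules that the _sp_plan instruments.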

Test Plan

```shell
# ulysses-sp
python examples/offline_inference/text_to_image/text_to_image.py --model Qwen/Qwen-Image --ulysses-degree 2 --cache-backend tea_cache
# ring attention
python examples/offline_inference/text_to_image/text_to_image.py --model Qwen/Qwen-Image --ring-degree 2 --cache-backend tea_cache
```

Test Result

| SP | TeaCache | e2e latency | Image |
| --- | --- | --- | --- |
| usp=2 | ON | 2.37s | qwen_image_output |
| ring=2 | ON | 2.78s | qwen_image_output |


wtomin added 2 commits March 23, 2026 22:32
Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com>
Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com>
@wtomin wtomin marked this pull request as ready for review March 23, 2026 14:46
@wtomin wtomin requested a review from hsliuustc0106 as a code owner March 23, 2026 14:46
@wtomin wtomin requested review from ZJY0516 and gcanlin March 23, 2026 14:46

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 23bda666b7


```python
# For zero_cond_t=True: timestep is doubled, modulate_index is created and
# sharded by the SequenceParallelSplitHook on modulate_index_prepare so that
# its sequence dimension matches the already-sharded hidden_states.
timestep, modulate_index = module.modulate_index_prepare(timestep, img_shapes)
```

P1: Keep TeaCache extractor aligned with zero_cond_t batching

For Qwen checkpoints with zero_cond_t=True (the image-edit variants), ModulateIndexPrepare.forward() doubles timestep here (qwen_image_transformer.py:141-154), so temb becomes 2 * batch. The rest of extract_qwen_context() still consumes that embedding as if it were batch: block.img_norm1(hidden_states, img_mod1) is still called without modulate_index, and postprocess() never chunks temb back down before module.norm_out, unlike QwenImageTransformer2DModel.forward() (qwen_image_transformer.py:1062-1065). With TeaCache enabled, Qwen edit models will therefore fail on the first forward with a batch-dimension mismatch instead of running the transformer.
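The shape mismatch the reviewer describes can be sketched with a small, torch-free illustration (all helper names here are hypothetical, not the real transformer code): when zero_cond_t=True doubles timestep, temb ends up with 2 * batch rows, and norm_out fails unless temb is chunked back down first, as QwenImageTransformer2DModel.forward() does.

```python
# Shape-only sketch of the P1 finding; hypothetical helpers, no torch.

batch = 2

def modulate_index_prepare(timestep):
    # zero_cond_t=True: append a zeroed copy, doubling the batch dim.
    return timestep + [0.0] * len(timestep)

def time_embed(timestep):
    # One embedding row per timestep entry.
    return [[t, t] for t in timestep]

def norm_out(hidden_states, temb):
    # Expects exactly one temb row per sample in the batch.
    if len(temb) != len(hidden_states):
        raise ValueError(f"batch mismatch: {len(temb)} vs {len(hidden_states)}")
    return hidden_states

hidden_states = [[1.0], [2.0]]  # batch rows
temb = time_embed(modulate_index_prepare([0.5, 0.7]))

# Without chunking, norm_out fails with a batch-dimension mismatch:
try:
    norm_out(hidden_states, temb)
except ValueError as e:
    print(e)  # batch mismatch: 4 vs 2

# The alignment the reviewer points at: chunk temb back to batch first.
temb_cond, _temb_zero = temb[:batch], temb[batch:]
out = norm_out(hidden_states, temb_cond)
print(len(out))  # 2
```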


@wtomin wtomin mentioned this pull request Mar 23, 2026
@SamitHuang SamitHuang added the ready label to trigger buildkite CI label Mar 23, 2026

@SamitHuang SamitHuang left a comment


LGTM

@gcanlin gcanlin merged commit 5aef6b9 into vllm-project:main Mar 23, 2026
7 of 8 checks passed
