
[Perf] torch compile for dit and rope kernel#317

Merged
ZJY0516 merged 15 commits into vllm-project:main from ZJY0516:torch-compile
Dec 19, 2025

Conversation

@ZJY0516 (Member) commented Dec 15, 2025

Purpose

Test Plan

Test Result

python text_to_image.py --model Tongyi-MAI/Z-Image-Turbo --num_inference_steps 9

Generation Configuration:
  Model: Tongyi-MAI/Z-Image-Turbo
  Inference steps: 9
  Cache backend: None (no acceleration)
  Parallel configuration: ulysses_degree=1
  Image size: 1024x1024
z-image              Time (s)
PR                   4.5626
w/o torch compile    4.6430
main                 5.4158
Generation Configuration:
  Model: Qwen/Qwen-Image
  Inference steps: 50
  Cache backend: None (no acceleration)
  Parallel configuration: ulysses_degree=1
  Image size: 1024x1024
Qwen-Image           Time (s)
PR                   61.4890
w/o torch compile    64.7892
main                 65.8851


Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
@ZJY0516 ZJY0516 changed the title [WIP] torch compile for dit [Perf] torch compile for dit and rope kernel Dec 17, 2025
@ZJY0516 ZJY0516 marked this pull request as ready for review December 17, 2025 13:03
@ZJY0516 ZJY0516 requested a review from SamitHuang December 17, 2025 13:03

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.


Comment thread vllm_omni/diffusion/compile.py Outdated
Comment on lines +48 to +52
if dynamic_arg_dims is not None:
    dims_map = {}
    for arg_name, dims in dynamic_arg_dims.items():
        if isinstance(dims, int):
            dims_map[arg_name] = [dims]

P1: Dynamic dims listed as arrays are dropped

In dit_support_compile the map of dynamic dimensions is only populated when dynamic_arg_dims values are integers, so any entries passed as lists (e.g., the new dynamic_arg_dims on ZImageTransformerBlock for x, attn_mask, and freqs_cis) are silently ignored. The decorator therefore never marks those tensor dimensions as dynamic before calling the compiled forward, leaving torch.compile to assume fixed shapes; subsequent invocations with different sequence lengths will either recompile or fail against the full-graph contract instead of using the requested dynamic handling.

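A hypothetical sketch of the fix this comment suggests (the helper name is assumed, not taken from the PR): normalize `dynamic_arg_dims` entries so that both a single int and a list of ints are kept, instead of silently dropping the list form.

```python
def normalize_dynamic_dims(dynamic_arg_dims):
    """Map each argument name to a list of dynamic dimension indices."""
    dims_map = {}
    for arg_name, dims in dynamic_arg_dims.items():
        if isinstance(dims, int):
            dims_map[arg_name] = [dims]
        elif isinstance(dims, (list, tuple)):
            # These entries were previously ignored by the int-only check.
            dims_map[arg_name] = list(dims)
        else:
            raise TypeError(f"unsupported dims spec for {arg_name!r}: {dims!r}")
    return dims_map
```

Each listed dimension could then be marked dynamic (e.g. via `torch._dynamo.mark_dynamic(tensor, dim)`) before invoking the compiled forward, so varying sequence lengths do not trigger recompilation.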

Comment on lines +136 to +139
cos = freqs_cis.real.squeeze(0).to(query.dtype)
sin = freqs_cis.imag.squeeze(0).to(query.dtype)
query = apply_rotary_emb(query, cos, sin)
key = apply_rotary_emb(key, cos, sin)

P1: Rotary embedding ignores per-batch frequencies

ZImageAttention.forward now squeezes the batch dimension off freqs_cis and feeds the result directly into apply_rotary_emb, which expects a 2D [tokens, dim] tensor. For freqs_cis shaped [batch, tokens, dim] (the pad_sequence output), squeeze(0) leaves the batch dimension intact when batch>1, so the rotary kernel reads strides as if there were no batch and mixes data from different samples, producing incorrect positional rotation or invalid memory accesses for multi-sample batches.

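A minimal repro of the hazard described above (shapes are assumed for illustration): `squeeze(0)` is a no-op when the leading dimension is larger than 1, so a `[batch, tokens, dim]` `freqs_cis` stays 3D whenever `batch > 1`.

```python
import torch

# freqs_cis as produced by pad_sequence: [batch, tokens, dim]
freqs_cis = torch.randn(2, 16, 8, dtype=torch.complex64)

cos = freqs_cis.real.squeeze(0)  # no-op: dim 0 has size 2, not 1
# A kernel expecting a 2D [tokens, dim] tensor would now read batched
# data with the wrong strides.
assert cos.dim() == 3  # still [2, 16, 8], not [16, 8]
```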

Comment thread vllm_omni/diffusion/diffusion_engine.py
Comment thread vllm_omni/diffusion/layers/rope.py Outdated
@@ -0,0 +1,136 @@
import torch
import triton
Collaborator


Do we need to add this dependency in pyproject?

Member Author @ZJY0516 commented Dec 17, 2025

Triton is installed automatically when vllm is installed on a CUDA platform.

Collaborator

I see. How about NPU? Shall we wrap the import in try/except to avoid failure on NPU?
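A minimal sketch of the guarded import being suggested (the fallback function names are placeholders, not the PR's actual code):

```python
# Guard the Triton import so merely importing this module does not fail
# on platforms without Triton (e.g. NPU), and dispatch accordingly.
try:
    import triton  # noqa: F401
    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False


def apply_rotary_emb(query, cos, sin):
    if HAS_TRITON:
        return _apply_rotary_emb_triton(query, cos, sin)  # placeholder name
    return _apply_rotary_emb_torch(query, cos, sin)       # placeholder name
```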


Collaborator

Could we refer to this implementation in vllm? More details of the discussion are here.

Member Author

Yes, we could. But I don't want to use vllm's custom op directly; we need a separate mechanism to dispatch.

Member Author

Done

Comment thread vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py Outdated
@SamitHuang SamitHuang added the ready label to trigger buildkite CI label Dec 18, 2025
Collaborator

@gcanlin gcanlin left a comment


LGTM, thanks!

@ZJY0516 ZJY0516 enabled auto-merge (squash) December 19, 2025 02:26
@ZJY0516 ZJY0516 disabled auto-merge December 19, 2025 02:26
@ZJY0516 ZJY0516 enabled auto-merge (squash) December 19, 2025 02:29
@ZJY0516 ZJY0516 merged commit 5216cf4 into vllm-project:main Dec 19, 2025
6 checks passed
@ZJY0516 ZJY0516 mentioned this pull request Dec 19, 2025
yenuo26 pushed a commit to yenuo26/vllm-omni that referenced this pull request Dec 29, 2025
princepride pushed a commit to princepride/vllm-omni that referenced this pull request Jan 10, 2026

Labels

ready label to trigger buildkite CI


3 participants