[Perf] torch compile for dit and rope kernel #317
Conversation
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
💡 Codex Review
Here are some automated review suggestions for this pull request.
```python
if dynamic_arg_dims is not None:
    dims_map = {}
    for arg_name, dims in dynamic_arg_dims.items():
        if isinstance(dims, int):
            dims_map[arg_name] = [dims]
```
Drop dynamic dims listed as arrays
In dit_support_compile, the map of dynamic dimensions is populated only when the dynamic_arg_dims values are integers, so entries passed as lists (e.g., the new dynamic_arg_dims on ZImageTransformerBlock for x, attn_mask, and freqs_cis) are silently ignored. The decorator therefore never marks those tensor dimensions as dynamic before calling the compiled forward, leaving torch.compile to assume fixed shapes; subsequent invocations with different sequence lengths will either recompile or fail against the full-graph contract instead of getting the requested dynamic handling.
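A minimal sketch of one way to handle both forms, assuming the decorator keeps building `dims_map` as in the snippet above (the list/tuple branch is the hypothetical addition):

```python
if dynamic_arg_dims is not None:
    dims_map = {}
    for arg_name, dims in dynamic_arg_dims.items():
        if isinstance(dims, int):
            dims_map[arg_name] = [dims]
        elif isinstance(dims, (list, tuple)):
            # Previously dropped: keep list-valued entries so every listed
            # dimension is marked dynamic before torch.compile traces the forward.
            dims_map[arg_name] = list(dims)
        else:
            raise TypeError(f"dynamic_arg_dims[{arg_name!r}] must be an int or a list of ints")
```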
```python
cos = freqs_cis.real.squeeze(0).to(query.dtype)
sin = freqs_cis.imag.squeeze(0).to(query.dtype)
query = apply_rotary_emb(query, cos, sin)
key = apply_rotary_emb(key, cos, sin)
```
Rotary embedding ignores per-batch frequencies
ZImageAttention.forward now squeezes the batch dimension off freqs_cis and feeds the result directly into apply_rotary_emb, which expects a 2D [tokens, dim] tensor. For freqs_cis shaped [batch, tokens, dim] (the pad_sequence output), squeeze(0) leaves the batch dimension intact when batch > 1, so the rotary kernel reads strides as if there were no batch and mixes data from different samples, producing incorrect positional rotations or invalid memory accesses for multi-sample batches.
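A hedged sketch of a guard for this, assuming freqs_cis arrives as [batch, tokens, dim] and reusing the names from the snippet above; it only drops the batch dimension when it is actually 1 and fails loudly otherwise, instead of silently corrupting the rotation:

```python
if freqs_cis.dim() == 3:
    # pad_sequence output is [batch, tokens, dim]; apply_rotary_emb expects [tokens, dim]
    assert freqs_cis.shape[0] == 1, (
        "apply_rotary_emb expects per-sample [tokens, dim] frequencies; "
        "batched freqs_cis needs a batch-aware rotary path"
    )
    freqs_cis = freqs_cis.squeeze(0)
cos = freqs_cis.real.to(query.dtype)
sin = freqs_cis.imag.to(query.dtype)
query = apply_rotary_emb(query, cos, sin)
key = apply_rotary_emb(key, cos, sin)
```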
```python
# new file, +136 lines
import torch
import triton
```
Do we need to add this dependency to pyproject?
Triton is installed automatically when vllm is installed on the CUDA platform.
I see. What about NPU? Shall we wrap it in a try/except to avoid an import failure on NPU?
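A sketch of the guarded import being suggested here (HAS_TRITON is a hypothetical flag, not something the PR defines):

```python
try:
    import triton  # present when vllm is installed on the CUDA platform
    import triton.language as tl
    HAS_TRITON = True
except ImportError:
    # e.g. on NPU, where Triton may not be available
    triton = None
    tl = None
    HAS_TRITON = False
```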
Could we refer to this implementation in vllm? More details of the discussion about it are here.
Yes, we could. But I don't want to use vllm's custom op directly; we need a separate mechanism to dispatch.
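One possible shape for such a dispatch mechanism (a sketch only; _rope_triton and _rope_torch are hypothetical backends, and HAS_TRITON is the flag from the guarded-import sketch above):

```python
import torch

def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # Prefer the Triton kernel on CUDA; fall back to a pure-PyTorch path elsewhere (e.g. NPU).
    if HAS_TRITON and x.is_cuda:
        return _rope_triton(x, cos, sin)
    return _rope_torch(x, cos, sin)
```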
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
Signed-off-by: wangyu31577 <wangyu31577@hundsun.com>
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
Purpose
Test Plan
Test Result
Essential Elements of an Effective PR Description Checklist
- Update `supported_models.md` and `examples` for a new model.