
[Perf] Optimize Wan2.2 rotary embedding#2393

Merged
david6666666 merged 2 commits into vllm-project:main from gcanlin:rope-perf
Apr 2, 2026

Conversation

@gcanlin
Collaborator

@gcanlin gcanlin commented Apr 1, 2026


Purpose

This change optimizes Wan2.2 rotary embedding application on NPU.

Previously, apply_rotary_emb_wan used an empty_like output tensor followed by two strided slice assignments for even and odd channels. On NPU this introduced significant InplaceCopy/ViewCopy overhead in the DiT hot path.

This change rewrites the rotary application to:

  • compute the rotated even/odd components directly,
  • combine them with torch.stack(...),
  • restore the original layout with flatten(...).

The new implementation avoids the expensive strided writeback pattern and reduces memory movement in self-attention.
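The rewrite can be sketched as follows. This is a minimal sketch, not the exact PR diff: the function name and argument layout are assumptions, with `cos`/`sin` taken to be already sliced to half the head dimension.

```python
import torch

def apply_rotary_emb_wan(x, cos, sin):
    """Interleaved RoPE without strided writeback (sketch, not the exact PR code).

    x:        (..., dim) query or key tensor
    cos, sin: broadcastable to (..., dim // 2)
    """
    x_even = x[..., 0::2]
    x_odd = x[..., 1::2]
    # Rotate each (even, odd) channel pair directly.
    out_even = x_even * cos - x_odd * sin
    out_odd = x_even * sin + x_odd * cos
    # stack + flatten restores the interleaved layout with one contiguous
    # write instead of two strided slice assignments into empty_like output.
    return torch.stack((out_even, out_odd), dim=-1).flatten(-2).type_as(x)
```

Numerically this is identical to the strided-write version; only the memory-access pattern changes.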

Why This Matters

RoPE is applied in every Wan2.2 self-attention block to both query and key, across all denoising steps. Because it sits on the critical path of DiT execution, even a small inefficiency in this function is amplified heavily in end-to-end latency.

Profiling before this change showed that apply_rotary_emb_wan was a major source of aclnnInplaceCopy, which then contributed to extra memory overhead and longer attention-side critical-path time.

Test Plan

python image_to_video.py \
     --model Wan-AI/Wan2.2-I2V-A14B-Diffusers \
     --image cherry_blossom.jpg \
     --prompt "Cherry blossoms swaying gently in the breeze, petals falling, smooth motion" \
     --negative-prompt "<optional quality filter>" \
     --height 512 \
     --width 768 \
     --num-frames 49 \
     --guidance-scale 4.0 \
     --num-inference-steps 20 \
     --flow-shift 12.0 \
     --fps 16 \
     --output i2v_output.mp4 \
     --enable-layerwise-offload \
     --ulysses-degree 8 \
     --vae-patch-parallel-size 8 \
     --vae-use-tiling

Accuracy test

 pytest -s -v tests/e2e/accuracy/wan22_i2v/test_wan22_i2v_video_similarity.py --run-level advanced_model

Test Result

| Code | Time |
| --- | --- |
| main (with VAE encode parallel) | 44 s |
| this PR | 33 s |

Test passed:

https://buildkite.com/vllm/vllm-omni/builds/5621/steps/canvas

INFO 04-01 10:08:00 [diffusion_model_runner.py:212] Peak GPU memory (this request): 52.13 GB reserved, 47.96 GB allocated, 4.17 GB pool overhead (8.0%)
(APIServer pid=872) INFO:     127.0.0.1:47216 - "GET /v1/videos/video_gen_659bb4d7b3b846a1a96b195f625b3eab HTTP/1.1" 200 OK
(APIServer pid=872) INFO 04-01 10:08:02 [diffusion_engine.py:103] Generation completed successfully.
(APIServer pid=872) INFO 04-01 10:08:02 [diffusion_engine.py:136] Post-processing completed in 0.3137 seconds
(APIServer pid=872) INFO 04-01 10:08:02 [diffusion_engine.py:139] DiffusionEngine.step breakdown: preprocess=15.36 ms, add_req_and_wait=241417.97 ms, postprocess=313.72 ms, total=241747.76 ms
(APIServer pid=872) INFO 04-01 10:08:02 [omni_base.py:162] [Summary] {}
(APIServer pid=872) INFO 04-01 10:08:03 [serving_video.py:185] Video response encoding (MP4+base64): 797.28 ms
(APIServer pid=872) INFO:     127.0.0.1:39794 - "GET /v1/videos/video_gen_659bb4d7b3b846a1a96b195f625b3eab HTTP/1.1" 200 OK
(APIServer pid=872) INFO 04-01 10:08:03 [api_server.py:1927] Video request video_gen_659bb4d7b3b846a1a96b195f625b3eab persisted /tmp/storage/video_gen_659bb4d7b3b846a1a96b195f625b3eab.mp4 output file.
(APIServer pid=872) INFO:     127.0.0.1:39800 - "GET /v1/videos/video_gen_659bb4d7b3b846a1a96b195f625b3eab HTTP/1.1" 200 OK
(APIServer pid=872) INFO:     127.0.0.1:39814 - "GET /v1/videos/video_gen_659bb4d7b3b846a1a96b195f625b3eab/content HTTP/1.1" 200 OK
online_video_e2e_latency_s=246.515
PASSED
GPU cleanup disabled
tests/e2e/accuracy/wan22_i2v/test_wan22_i2v_video_similarity.py::test_wan22_i2v_serving_matches_diffusers_video_similarity
=== PRE-TEST GPU CLEANUP ===
GPU cleanup disabled
INFO 04-01 10:08:05 [scheduler.py:231] Chunked prefill is enabled with max_num_batched_tokens=2048.


Signed-off-by: gcanlin <canlinguosdu@gmail.com>
@gcanlin gcanlin requested a review from hsliuustc0106 as a code owner April 1, 2026 02:34
@hsliuustc0106
Collaborator

could this help other models as well? cc @fhfuih @wtomin

@bjf-frz
Contributor

bjf-frz commented Apr 1, 2026

Performance has seen a substantial improvement in my tests.
1 die on A100.
I ran a request like:
curl -s http://127.0.0.1:8000/v1/videos \
  -H "Accept: application/json" \
  -F "prompt=xxx。" \
  -F "input_reference=@coffee.png" \
  -F "width=832" \
  -F "height=480" \
  -F "fps=16" \
  -F "num_frames=81" \
  -F "guidance_scale=1.0" \
  -F "flow_shift=5.0" \
  -F "num_inference_steps=4" \
  -F "seed=42"

Before the change:
'e2e_table': [{'request_id': 'video_gen_3ddc868cbfd0446b8c3adba73772c943',
'e2e_total_ms': 52573.21834564209,
'e2e_total_tokens': 0,
'transfers_total_time_ms': 0.0,
'transfers_total_kbytes': 0.0}]}

After the change:
'e2e_table': [{'request_id': 'video_gen_1f453bc939984ef9aefc66fa6ce24f08',
'e2e_total_ms': 44730.52096366882,
'e2e_total_tokens': 0,
'transfers_total_time_ms': 0.0,
'transfers_total_kbytes': 0.0}]}

~53 s -> ~45 s end to end

@gcanlin
Collaborator Author

gcanlin commented Apr 1, 2026

Thanks @bjf-frz for testing this PR on A100. It looks like a substantial optimization for Wan2.2. Let's wait for the nightly test.

@gcanlin
Collaborator Author

gcanlin commented Apr 1, 2026

could this help other models as well? cc @fhfuih @wtomin

I searched for apply_rotary_emb in the codebase and found several different implementations. I think we should unify on one best implementation and have every model reuse it.

@fhfuih
Contributor

fhfuih commented Apr 1, 2026

could this help other models as well? cc @fhfuih @wtomin

I searched for apply_rotary_emb in the codebase and found several different implementations. I think we should unify on one best implementation and have every model reuse it.

I've completed the investigation. Here's what I found:

Summary

PR 2393 Purpose: Optimizes Wan2.2's apply_rotary_emb_wan function by replacing the inefficient torch.empty_like + strided slice assignment pattern with torch.stack + flatten, achieving a 25% speedup (44s → 33s).

Same Problem Found

helios_transformer.py:53-68 has the EXACT same anti-pattern:

out = torch.empty_like(hidden_states)
out[..., 0::2] = x_1 * cos[..., 0::2] - x_2 * sin[..., 1::2]
out[..., 1::2] = x_1 * sin[..., 1::2] + x_2 * cos[..., 0::2]
return out.type_as(hidden_states)
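A drop-in rewrite of that block in the stack + flatten style would look roughly like this. This is a sketch: `x_1`/`x_2` are assumed to be the even/odd channel halves of `hidden_states` (they are not defined in the excerpt above), and the `cos`/`sin` slicing is carried over verbatim.

```python
import torch

def helios_rope_stack_flatten(hidden_states, cos, sin):
    # Assumption: x_1 / x_2 are the even / odd channels of hidden_states.
    x_1 = hidden_states[..., 0::2]
    x_2 = hidden_states[..., 1::2]
    out_even = x_1 * cos[..., 0::2] - x_2 * sin[..., 1::2]
    out_odd = x_1 * sin[..., 1::2] + x_2 * cos[..., 0::2]
    # One contiguous write instead of two strided slice assignments.
    return torch.stack((out_even, out_odd), dim=-1).flatten(-2).type_as(hidden_states)
```

The arithmetic is unchanged from the anti-pattern version; only the output assembly differs.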

Models Already Optimized

  • stable_audio - Uses torch.cat approach (line 55)
  • qwen_image - Uses torch.stack + flatten (line 225)
  • omnigen2 - Uses torch.stack + flatten (line 202)

VLLM Source Code Analysis

The VLLM core (~/workspace/VLLM/vllm/model_executor/layers/rotary_embedding/common.py:142-182) already uses the optimized pattern in ApplyRotaryEmb.forward_static:

  • Neox style: torch.cat((o1, o2), dim=-1)
  • GPT-J/interleaved: torch.stack((o1, o2), dim=-1).flatten(-2) ← Same as PR 2393

The apply_rotary_emb from vllm.vllm_flash_attn.layers.rotary is a compiled CUDA kernel (not Python source), already optimized at the kernel level.
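The two styles differ only in how channel pairs are formed and recombined; a minimal sketch of both (assumed argument shapes, not vLLM's actual signatures):

```python
import torch

def rotate_neox(x, cos, sin):
    # Neox style: pair the first half with the second half, recombine with cat.
    x1, x2 = x.chunk(2, dim=-1)
    o1 = x1 * cos - x2 * sin
    o2 = x2 * cos + x1 * sin
    return torch.cat((o1, o2), dim=-1)

def rotate_gptj(x, cos, sin):
    # GPT-J / interleaved style: pair even with odd channels, recombine
    # with stack + flatten -- the same pattern as this PR.
    x1, x2 = x[..., 0::2], x[..., 1::2]
    o1 = x1 * cos - x2 * sin
    o2 = x2 * cos + x1 * sin
    return torch.stack((o1, o2), dim=-1).flatten(-2)
```

Both avoid any strided writeback into a preallocated output, which is why they map cleanly onto NPU as well as GPU.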

@gcanlin
Collaborator Author

gcanlin commented Apr 1, 2026

@fhfuih Thanks for the detailed investigation! I think we can try the custom op ApplyRotaryEmb.

@fhfuih
Contributor

fhfuih commented Apr 1, 2026

@fhfuih Thanks for the detailed investigation! I think we can try the custom op ApplyRotaryEmb.

No prob, it was AI anyway 😼

@gcanlin gcanlin removed the nightly-test label to trigger buildkite nightly test CI label Apr 1, 2026
@gcanlin
Collaborator Author

gcanlin commented Apr 2, 2026

@hsliuustc0106 @fhfuih @wtomin I created RFC #2436 to unify the implementation across the other models and to call for community help. This PR doesn't change much code, so we can easily cherry-pick it to v0.18.0-post1 if we want.

@david6666666
Collaborator

LGTM. Thanks for the RFC on optimizing the other related models, and for the source code analysis.

@david6666666 david6666666 merged commit e2892ef into vllm-project:main Apr 2, 2026
8 checks passed
willamhou added a commit to willamhou/vllm-omni that referenced this pull request Apr 3, 2026
…memory (vllm-project#2436)

Replace strided slice assignment (`out[..., 0::2] = ...`) with
`torch.stack(..., dim=-1).flatten(-2, -1)` in Helios RoPE, matching the
optimization applied to Wan2.2 in PR vllm-project#2393 (25% speedup on NPU).

The stack+flatten pattern produces contiguous memory layout, avoiding
non-sequential write patterns that hurt GPU/NPU cache performance.
Math is identical — verified bit-exact across float32/float16/bfloat16.

This is Phase 1 of vllm-project#2436 (Helios only). Phase 2 (unified utility) to follow.

Signed-off-by: willamhou <willamhou@ceresman.com>
linyueqian pushed a commit to JuanPZuluaga/vllm-omni that referenced this pull request Apr 3, 2026
Signed-off-by: gcanlin <canlinguosdu@gmail.com>
Signed-off-by: JuanPZuluaga <juanz9312@gmal.com>
david6666666 pushed a commit to david6666666/vllm-omni that referenced this pull request Apr 7, 2026
Signed-off-by: gcanlin <canlinguosdu@gmail.com>
Signed-off-by: David Chen <530634352@qq.com>
david6666666 added a commit that referenced this pull request Apr 7, 2026
Signed-off-by: gcanlin <canlinguosdu@gmail.com>
Signed-off-by: David Chen <530634352@qq.com>
Co-authored-by: Canlin Guo <canlinguosdu@gmail.com>
vraiti pushed a commit to vraiti/vllm-omni that referenced this pull request Apr 9, 2026
Signed-off-by: gcanlin <canlinguosdu@gmail.com>
willamhou added a commit to willamhou/vllm-omni that referenced this pull request Apr 19, 2026