[PERF] Wan2.2 support rmsnorm fused op #2583

Merged
gcanlin merged 19 commits into vllm-project:main from fan2956:main_add_wan22_rmsnorm on Apr 16, 2026

Conversation

@fan2956 (Contributor) commented Apr 8, 2026


Purpose

  • Add norm layers, including LayerNorm and RMSNorm.
  • Have Wan2.2 use the fused LayerNorm and RMSNorm ops on NPU.
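For reference, the operation these fused kernels accelerate is standard RMSNorm: scale the input by its reciprocal root-mean-square, then by a learned weight. A minimal pure-Python sketch of the semantics (illustrative only; the actual fused NPU kernels are not shown in this PR description):

```python
import math


def rms_norm(x, weight, eps=1e-6):
    # rms(x) = sqrt(mean(x_i^2) + eps); output = weight * x / rms(x)
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [w * v / rms for v, w in zip(x, weight)]


print(rms_norm([1.0, 2.0, 2.0], [1.0, 1.0, 1.0]))
```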

Test Plan

vllm serve /home/y00958577/Wan2.2-I2V-A14B-Diffusers/ \
  --omni \
  --port 8099 \
  --usp 8 \
  --use-hsdp \
  --enforce-eager \
  --log-stats \
  --profiler-config '{"profiler": "torch", "torch_profiler_dir": "./vllm_profile", "torch_profiler_with_stack": "False"}'\
  --vae-patch-parallel-size 8 \
  --vae-use-tiling

curl

curl -X POST http://localhost:8099/v1/videos \
    -F "prompt='A frontal close-up of a brown hare, shot from a low angle to create an intimate yet solemn visual impact. The hare's round, jet-black eyes stare straight into the lens, mixing a wild animal's alertness with a faint, hard-to-describe gentle curiosity, as if holding a silent cross-species dialogue with the viewer. Its coat shows a richly layered brown gradient, from a light cream belly to a dark brown back, every strand of fur distinct and shimmering with a silk-like sheen under the side light. Its three pairs of slender white whiskers tremble slightly with each breath, occasionally swaying as they catch information from the air.
Both of its signature long ears stand fully erect; the outer ear is covered in short, dense brown fur, while the inner side reveals a network of pink blood vessels, blood flow faintly visible beneath paper-thin skin. The ears turn slightly from time to time, precisely locating sound sources. The background is a clear azure sky where fluffy white cumulus clouds drift slowly sideways; their shadows alternate over the hare's head as the light shifts between bright and dim. Bright sunlight pours down from the upper left at a 45-degree angle, casting a soft Rembrandt-style shadow on the right side of the hare's face and enhancing its facial depth and fur texture.
The hare's moist black nose twitches rapidly three to four times per second, an instinctive motion for sensing chemical signals; its pink three-part mouth opens slightly, revealing white incisors chewing cud as the lower jaw grinds side to side at a steady rhythm. The shot uses a wide aperture and shallow depth of field, with focus locked on the plane of the hare's eyes; the sky and distant green vegetation blur into round colored bokeh, while a few tender green grass blades intrude at the edge of the frame, swaying in slow arcs and hinting at a mild spring breeze.
The whole scene exudes a tranquil pastoral poetry, with warm, saturated colors full of vitality. After five seconds of eye contact, the hare completes a full nictitating-membrane blink typical of lagomorphs, the third eyelid sliding laterally across the eyeball from the inner corner, then slowly tilts its head fifteen degrees to the right, a behavior that in ethology signals cognitive processing and curiosity; its ears tilt in the same direction before it returns to a forward gaze, whiskers relaxed, completing this brief yet precious natural record.'" \
    -F "input_reference=@/home/zf/vllm-omni/rabbit.jpeg" \
    -F "size=832x480" \
    -F "seconds=5" \
    -F "fps=12" \
    -F "num_frames=81" \
    -F "guidance_scale=1" \
    -F "guidance_scale_2=1" \
    -F "flow_shift=5.0" \
    -F "num_inference_steps=4" \
    -F "seed=42"

Test Result

  1. DiT infer time

before: 1.51s/step
after: 1.46s/step
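Those per-step times imply roughly a 3.3% speedup; a quick check of the arithmetic:

```python
# Per-step DiT times quoted above.
before, after = 1.51, 1.46
speedup_pct = (before - after) / before * 100
print(f"{speedup_pct:.1f}% faster per step")
```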

  2. Benchmark
python3 benchmarks/diffusion/diffusion_benchmark_serving.py \
        --backend v1/videos \
        --dataset vbench \
        --task i2v \
        --num-prompts 10 \
        --height 832 \
        --width 480 \
        --fps 16 \
        --num-frames 81 \
        --num-inference-steps 4 \
        --port 8099

before:
Request throughput (req/s): 0.06
Latency Mean (s): 16.3693
Latency Median (s): 16.3734
Latency P99 (s): 16.7011
Latency P95 (s): 16.7006

after:

Request throughput (req/s): 0.06
Latency Mean (s): 16.0344
Latency Median (s): 16.0346
Latency P99 (s): 16.0371
Latency P95 (s): 16.0368

  3. Accuracy test

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan. Please provide the test scripts & test commands. Please state the reasons if your codes don't require additional test scripts. For test file guidelines, please check the test style doc
  • The test results. Please paste the results comparison before and after, or the e2e results.
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model. Please run mkdocs serve to sync the documentation editions to ./docs.
  • (Optional) Release notes update. If your change is user-facing, please update the release notes draft.


Signed-off-by: fan2956 <zhoufan53@huawei.com>
@fan2956 fan2956 requested a review from hsliuustc0106 as a code owner April 8, 2026 08:18
@chatgpt-codex-connector

Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits. Credits must be used to enable repository-wide code reviews.

Signed-off-by: fan2956 <zhoufan53@huawei.com>
@david6666666 (Collaborator) left a comment

I found one blocking correctness issue in the new RMSNorm path.


    def forward_cuda(
        self,
        x: torch.Tensor,

Collaborator comment:

RMSNorm is called as self.norm_q(query) / self.norm_k(key), so on CUDA/HIP/XPU the custom-op dispatcher will invoke forward_cuda/forward_hip with only x. The new signature requires scale and shift and then forwards them into forward_native(), which only accepts x. In practice, single-rank non-NPU runs will fail with a TypeError before any inference starts.
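The failure mode can be reproduced without any GPU. A minimal pure-Python sketch (hypothetical class names, normalization math elided) of the dispatcher calling the backend hook with only x:

```python
class CustomOp:
    """Stand-in for the custom-op dispatcher: forward() routes to a backend hook."""

    def forward(self, x):
        return self.forward_cuda(x)  # dispatcher passes only x, positionally


class BrokenRMSNorm(CustomOp):
    def forward_native(self, x):
        return x  # normalization math elided for the sketch

    # Requiring scale/shift here breaks self.norm_q(query)-style call sites.
    def forward_cuda(self, x, scale, shift):
        return self.forward_native(x)


try:
    BrokenRMSNorm().forward([1.0, 2.0])
except TypeError as e:
    print(f"TypeError: {e}")
```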

@gcanlin (Collaborator) commented Apr 9, 2026

cc @tjtanaa @xuechendi @ZJY0516

    ) -> torch.Tensor:
        return self.forward_native(x, scale, shift)

    def forward_native(

Collaborator comment:

Will this affect the perf of CUDA/XPU?

Contributor comment:

The code is the same as the main-repo RMSNorm implementation. Functionality-wise, I validated on XPU and the generated video looks good. Perf-wise, I will submit a follow-up PR later for XPU-specific perf optimization on Wan2.2, including plugging in vllm-xpu-kernels.rms_norm.

@hsliuustc0106 (Collaborator)

Add a perf regression test. cc @david6666666, please check for perf regression on CUDA.

    ) -> torch.Tensor:
        return self.forward_native(x, scale, shift)

    def forward_hip(

Contributor comment:

I am getting this error:

ERROR 04-09 09:12:25 [diffusion_worker.py:748] TypeError: RMSNorm.forward_hip() missing 2 required positional arguments: 'scale' and 'shift'
ERROR 04-09 09:12:25 [diffusion_worker.py:456] Error executing RPC: RMSNorm.forward_hip() missing 2 required positional arguments: 'scale' and 'shift'
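One way to make both call patterns work (a sketch, not necessarily what this PR merged) is to give scale and shift defaults, so the dispatcher's forward_hip(x) and an NPU caller's forward_hip(x, scale, shift) both succeed. Pure Python, with the tensor math reduced to 1-D lists:

```python
import math


class RMSNormSketch:
    def __init__(self, eps=1e-6):
        self.eps = eps

    def forward_native(self, x, scale=None, shift=None):
        rms = math.sqrt(sum(v * v for v in x) / len(x) + self.eps)
        y = [v / rms for v in x]
        if scale is not None:  # optional fused (1 + scale) modulation
            y = [v * (1.0 + s) for v, s in zip(y, scale)]
        if shift is not None:
            y = [v + s for v, s in zip(y, shift)]
        return y

    # Optional args keep the dispatcher's single-argument call sites valid.
    def forward_hip(self, x, scale=None, shift=None):
        return self.forward_native(x, scale, shift)


print(RMSNormSketch().forward_hip([3.0, 4.0]))  # no TypeError with only x
```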

@tjtanaa (Contributor) commented Apr 9, 2026

@fan2956 Can you share some benchmark numbers, since these are perf-related changes? Any hardware will do.

@tjtanaa (Contributor) commented Apr 9, 2026

@gcanlin @hsliuustc0106
On MI300X, before and after:

Server command from this PR:

#!/bin/bash
vllm serve Wan-AI/Wan2.2-T2V-A14B-Diffusers \
--omni \
--port 8099 \
--usp 8 \
--use-hsdp \
--vae-patch-parallel-size 8 \
--vae-use-tiling \
--log-stats

Benchmark command:

python3 main_add_wan22_rmsnorm/benchmarks/diffusion/diffusion_benchmark_serving.py \
        --backend v1/videos --dataset vbench --task t2v --num-prompts 10 \
        --height 480 --width 640 --fps 16 --num-frames 80 --port 8099
Metric               Before     After      Change     % Change
Benchmark Duration   382.83s    379.16s    -3.67s     0.96% faster
Latency Mean         38.28s     37.92s     -0.37s     0.96% faster
Latency Median       38.27s     36.08s     -2.19s     5.72% faster
Latency P95          38.45s     44.85s     +6.40s     16.6% slower ⚠️
Latency P99          38.52s     49.24s     +10.72s    27.8% slower ⚠️
Full log:

Before

================= Serving Benchmark Result =================
Backend:                                 v1/videos      
Model:                                   default        
Dataset:                                 vbench         
Task:                                    t2v            
--------------------------------------------------
Benchmark duration (s):                  382.83         
Request rate:                            inf            
Max request concurrency:                 1              
Successful requests:                     10/10             
--------------------------------------------------
Request throughput (req/s):              0.03           
Latency Mean (s):                        38.2826        
Latency Median (s):                      38.2727        
Latency P99 (s):                         38.5191        
Latency P95 (s):                         38.4547        

============================================================



After


================= Serving Benchmark Result =================
Backend:                                 v1/videos      
Model:                                   default        
Dataset:                                 vbench         
Task:                                    t2v            
--------------------------------------------------
Benchmark duration (s):                  379.16         
Request rate:                            inf            
Max request concurrency:                 1              
Successful requests:                     10/10             
--------------------------------------------------
Request throughput (req/s):              0.03           
Latency Mean (s):                        37.9153        
Latency Median (s):                      36.0780        
Latency P99 (s):                         49.2427        
Latency P95 (s):                         44.8521        

============================================================

fan2956 added 4 commits April 9, 2026 20:55
Signed-off-by: fan2956 <zhoufan53@huawei.com>
Signed-off-by: fan2956 <zhoufan53@huawei.com>
Signed-off-by: fan2956 <zhoufan53@huawei.com>
Signed-off-by: fan2956 <zhoufan53@huawei.com>
@fan2956 (Contributor, Author) commented Apr 11, 2026

@fan2956 can you share some benchmark values as this is perf related changes? Any hardware will do.

Done.


    # 2. Cross-attention
-   norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
+   norm_hidden_states = self.norm2(hidden_states).type_as(hidden_states)

Collaborator comment:

Dropping .float() here, combined with the new LayerNorm (which does not upcast in forward_native), means the cross-attn norm now runs in bf16 instead of fp32 on non-NPU devices. FP32LayerNorm upcast internally; the new LayerNorm does not, so this is a silent numerical regression. Either upcast inside forward_native or restore hidden_states.float().

        return self.forward_native(x)

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:

Collaborator comment:

Suggested change:

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        orig_dtype = x.dtype
        return super().forward(x.float()).to(orig_dtype)

This is named as a drop-in for FP32LayerNorm but doesn't upcast. Match the original semantics.

@gcanlin gcanlin self-assigned this Apr 11, 2026
fan2956 added 4 commits April 13, 2026 11:40
Signed-off-by: fan2956 <zhoufan53@huawei.com>
Signed-off-by: fan2956 <zhoufan53@huawei.com>
Signed-off-by: fan2956 <zhoufan53@huawei.com>
Signed-off-by: fan2956 <zhoufan53@huawei.com>
@david6666666 david6666666 added the nightly-test label to trigger buildkite nightly test CI label Apr 13, 2026
Comment thread vllm_omni/diffusion/layers/norm.py Outdated
Comment thread vllm_omni/diffusion/layers/norm.py Outdated
fan2956 added 2 commits April 14, 2026 15:25
Signed-off-by: fan2956 <zhoufan53@huawei.com>
Signed-off-by: fan2956 <zhoufan53@huawei.com>
@gcanlin gcanlin added ready label to trigger buildkite CI and removed nightly-test label to trigger buildkite nightly test CI labels Apr 15, 2026
gcanlin added 2 commits April 15, 2026 07:31
Signed-off-by: gcanlin <canlinguosdu@gmail.com>
Signed-off-by: gcanlin <canlinguosdu@gmail.com>
@gcanlin (Collaborator) left a comment

I helped fix the conflict and the lint error. Please check again.

@Gaohan123 Gaohan123 added this to the v0.20.0 milestone Apr 15, 2026
@Gaohan123 (Collaborator) left a comment

Please supplement a unit test for it. Thanks.

Signed-off-by: gcanlin <canlinguosdu@gmail.com>
@gcanlin gcanlin enabled auto-merge (squash) April 16, 2026 02:56
@gcanlin (Collaborator) commented Apr 16, 2026

@Gaohan123 @hsliuustc0106 Please help force-merge; the CI is stuck on the docs check.

Comment thread vllm_omni/diffusion/layers/norm.py
@gcanlin gcanlin merged commit f1cb4eb into vllm-project:main Apr 16, 2026
8 checks passed
lvliang-intel pushed a commit to lvliang-intel/vllm-omni that referenced this pull request Apr 20, 2026
Signed-off-by: fan2956 <zhoufan53@huawei.com>
Signed-off-by: gcanlin <canlinguosdu@gmail.com>
Co-authored-by: gcanlin <canlinguosdu@gmail.com>

Labels

ready label to trigger buildkite CI

8 participants