[PERF] Wan2.2 support rmsnorm fused op #2583
Conversation
Signed-off-by: fan2956 <zhoufan53@huawei.com>
david6666666 left a comment
I found one blocking correctness issue in the new RMSNorm path.
```python
def forward_cuda(
    self,
    x: torch.Tensor,
    ...
) -> torch.Tensor:
    return self.forward_native(x, scale, shift)
```

RMSNorm is called as `self.norm_q(query)` / `self.norm_k(key)`, so on CUDA/HIP/XPU the custom-op dispatcher will invoke `forward_cuda`/`forward_hip` with only `x`. This new signature requires `scale` and `shift`, and then forwards them into `forward_native()`, which only accepts `x`. In practice, single-rank non-NPU runs will fail with a `TypeError` before any inference starts.
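One way to resolve this (a sketch only, not the PR's actual fix; the class skeleton and defaults below are assumptions) is to give `scale` and `shift` defaults of `None`, so the dispatcher's plain `norm(x)` call still works while the modulated call sites keep passing both tensors:

```python
from typing import Optional

import torch
from torch import nn


class RMSNorm(nn.Module):
    """Minimal RMSNorm sketch with optional modulation inputs."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward_native(
        self,
        x: torch.Tensor,
        scale: Optional[torch.Tensor] = None,
        shift: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        orig_dtype = x.dtype
        x = x.float()
        x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        x = x * self.weight.float()
        if scale is not None:  # modulation applied only when provided
            x = x * scale
        if shift is not None:
            x = x + shift
        return x.to(orig_dtype)

    # CUDA/HIP paths delegate with the same optional signature, so a bare
    # self.norm_q(query) call no longer raises TypeError.
    def forward_cuda(self, x, scale=None, shift=None):
        return self.forward_native(x, scale, shift)

    forward_hip = forward_cuda
```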
On the new `forward_native`: will this affect the perf of cuda/xpu?
The code is the same as the main-repo RMSNorm implementation. Functionality-wise, I validated on XPU and the generated video looks good.
Perf-wise, I will submit a follow-up PR for XPU-specific perf optimization on Wan2.2 later, including plugging in `vllm-xpu-kernels.rms_norm`.
Please add a perf regression test. cc @david6666666, please check for perf regression on CUDA.
```python
def forward_hip(
    ...
) -> torch.Tensor:
    return self.forward_native(x, scale, shift)
```
I am getting this error:

```
ERROR 04-09 09:12:25 [diffusion_worker.py:748] TypeError: RMSNorm.forward_hip() missing 2 required positional arguments: 'scale' and 'shift'
ERROR 04-09 09:12:25 [diffusion_worker.py:456] Error executing RPC: RMSNorm.forward_hip() missing 2 required positional arguments: 'scale' and 'shift'
```
@fan2956 can you share some benchmark values, since this is a perf-related change? Any hardware will do.
@gcanlin @hsliuustc0106 Server command from this PR: Benchmark command:
full log
done
```diff
  # 2. Cross-attention
- norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
+ norm_hidden_states = self.norm2(hidden_states).type_as(hidden_states)
```
Dropping `.float()` here, combined with the new `LayerNorm` (which doesn't upcast in `forward_native`), means the cross-attn norm now runs in bf16 instead of fp32 on non-NPU. `FP32LayerNorm` upcasts internally; the new `LayerNorm` does not. This is a silent numerical regression. Either upcast inside `forward_native` or restore `hidden_states.float()`.
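A quick way to see the risk (an illustrative check with assumed shapes, not code from this PR): LayerNorm computed directly in bf16 can diverge from the upcast-then-downcast path, because the mean/variance reduction loses precision in bf16:

```python
import torch
import torch.nn.functional as F

# Illustrative only: compare LayerNorm in bf16 vs the FP32LayerNorm-style
# path (compute in fp32, cast the result back to bf16).
torch.manual_seed(0)
x = torch.randn(4, 64).to(torch.bfloat16)
w = torch.ones(64)
b = torch.zeros(64)

y_bf16 = F.layer_norm(x, (64,), w.to(torch.bfloat16), b.to(torch.bfloat16), 1e-5)
y_fp32 = F.layer_norm(x.float(), (64,), w, b, 1e-5).to(torch.bfloat16)

# The two results need not match bit-for-bit; the gap is the silent
# numerical drift the review comment is warning about.
max_diff = (y_bf16.float() - y_fp32.float()).abs().max().item()
```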
On:

```python
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
```

Suggested change:

```python
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
    orig_dtype = x.dtype
    return super().forward(x.float()).to(orig_dtype)
```
This is named as a drop-in for FP32LayerNorm but doesn't upcast. Match the original semantics.
gcanlin left a comment
I helped fix the conflict and the lint error. Please check again.
Gaohan123 left a comment
Please supplement unit tests for it. Thanks.
@Gaohan123 @hsliuustc0106 Please help force-merge; CI is stuck on the docs job.
Signed-off-by: fan2956 <zhoufan53@huawei.com> Signed-off-by: gcanlin <canlinguosdu@gmail.com> Co-authored-by: gcanlin <canlinguosdu@gmail.com>
Purpose
- Add norm layers, including LayerNorm and RMSNorm.
- Wan2.2 uses fused LayerNorm and RMSNorm ops on NPU.
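The idea above can be sketched as a dispatch between an unfused reference path and a single fused kernel on NPU. This is an illustrative sketch, not the PR's code; in particular the `torch_npu.npu_rms_norm` name, its tuple return, and the availability check are assumptions to be verified against the actual torch_npu API:

```python
import torch


def rms_norm_native(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Unfused reference: several small kernels (square, mean, rsqrt, mul)."""
    orig_dtype = x.dtype
    xf = x.float()
    var = xf.pow(2).mean(dim=-1, keepdim=True)
    return (xf * torch.rsqrt(var + eps) * weight.float()).to(orig_dtype)


def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # On Ascend NPU, route to a single fused kernel; the exact op name,
    # signature, and return shape here are assumptions, not from the diff.
    if hasattr(torch, "npu") and torch.npu.is_available():
        import torch_npu  # Ascend PyTorch extension

        return torch_npu.npu_rms_norm(x, weight, eps)[0]
    # Everywhere else, fall back to the unfused composition above.
    return rms_norm_native(x, weight, eps)
```

The perf win comes from replacing the chain of elementwise/reduction kernels with one launch, which is what the before/after step times below are measuring.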
Test Plan
curl
Test Result
Per-step time: before 1.51 s/step, after 1.46 s/step.

| Metric | Before | After |
|---|---|---|
| Request throughput (req/s) | 0.06 | 0.06 |
| Latency Mean (s) | 16.3693 | 16.0344 |
| Latency Median (s) | 16.3734 | 16.0346 |
| Latency P99 (s) | 16.7011 | 16.0371 |
| Latency P95 (s) | 16.7006 | 16.0368 |
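For context, the relative improvement implied by the reported numbers works out as follows (simple arithmetic on the values above, nothing more):

```python
# Relative gains computed from the figures reported in the PR description.
step_before, step_after = 1.51, 1.46
mean_before, mean_after = 16.3693, 16.0344

step_gain_pct = (step_before - step_after) / step_before * 100
mean_latency_gain_pct = (mean_before - mean_after) / mean_before * 100

print(f"per-step: {step_gain_pct:.1f}% faster")            # ~3.3%
print(f"mean latency: {mean_latency_gain_pct:.1f}% lower")  # ~2.0%
```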
Essential Elements of an Effective PR Description Checklist
Update `supported_models.md` and `examples` for a new model. Please run `mkdocs serve` to sync the documentation editions to `./docs`.