[diffusion] kernel fusion: gated residual layernorm scale shift and layernorm scale shift kernel fusion for Qwen-Image, WAN and HunyuanVideo#14717
Conversation
Summary of Changes

Hello @jianyingzhu, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request significantly enhances the performance of the Qwen-Image and WAN diffusion models by implementing kernel fusion for critical operations. By combining residual connections, gating, LayerNorm, and scale/shift into highly optimized CUDA kernels, the changes aim to reduce GPU overhead and improve overall efficiency, while maintaining numerical precision across various floating-point formats.
Code Review
This pull request introduces new fused CUDA kernels for Layer Normalization (LN) combined with scale and shift operations, and also a fused residual connection, gating, LN, and scale/shift. The changes involve adding new C++ CUDA kernel files (fused_layernorm_scale_shift.cu), registering these kernels in common_extension.cc and sgl_kernel_ops.h, and exposing them via Python wrappers in sgl_kernel/elementwise.py. The layernorm.py file is updated to leverage these new fused kernels within its LayerNormScaleShift and ScaleResidualLayerNormScaleShift classes, including logic to handle various tensor dimensions for scale, shift, and gate parameters, and a fallback to non-fused operations. The qwen_image.py model is modified to utilize the new LayerNormScaleShift class, replacing existing LayerNorm instances and adapting the forward pass to explicitly pass scale and shift parameters. Additionally, text_encoding.py gains a try-except block to gracefully handle Hugging Face encoder-decoder models.

Review comments highlighted several issues:
- an incorrect argument order and missing parameters in a call to fused_scale_residual_layernorm_scale_shift in layernorm.py;
- overly complex and potentially incorrect dimension handling for scale_blc and shift_blc in both the LayerNormScaleShift and ScaleResidualLayerNormScaleShift classes;
- a mismatch between the Python code and the C++ kernel's expectations for 3D gate tensors;
- redundant if blocks for 4D scale/shift in LayerNormScaleShift;
- unused debug lines.

Furthermore, the bfloat16 accuracy test tolerance was noted as being too high.
sgl-kernel/csrc/sgl_diffusion/elementwise/fused_layernorm_scale_shift.cu
sgl-kernel/tests/sgl_diffusion/test_fused_layernorm_scale_shift.py
Hi, I noticed today that this PR and my PR do similar things. You additionally implemented the fusion for layernorm_scale_shift, while I added support for RMSNorm in my PR. There are also some minor differences: for example, I used the Welford algorithm, while you used shared memory (smem) to temporarily store x. Could our PRs be merged?
We need to add a test/benchmark for this kernel in sgl-kernel.
There is indeed some overlap. Our CUDA kernel code is based on CUTLASS's LayerNorm. We can discuss it further.
Thank you, the benchmark has been added and the results are shown in the PR.
Some high-level suggestions: the operation of this kernel isn't particularly complex. I noticed the initial commit seemed to have a CUDA DSL implementation; why was the final version switched to raw .cu? Also, the code length for such a simple kernel seems excessive. Do we really need this much dispatch logic? Perhaps we could start by writing a more generalized kernel first. Also, could you share the micro-benchmark results for this kernel?
Well, we found that sgl-kernel currently doesn't have any kernels written with CuTe DSL, so we chose a style that stays more consistent with the existing code. The complicated kernel launch logic mainly comes from different models having different input format requirements, for example: 2D/3D inputs, w/wo affine, and w/wo gate residual.
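To illustrate why the launch logic multiplies out, here is a small sketch of how the independent axes named above combine. The labels are illustrative only, not the actual template parameters used in fused_layernorm_scale_shift.cu:

```python
from itertools import product

# Illustrative axes mirroring the cases named above (2D/3D inputs,
# w/wo affine, w/wo gate residual) -- not the real template parameters.
ranks = ["2d", "3d"]
affine = ["affine", "no_affine"]
gating = ["gated_residual", "plain"]

# Each combination typically needs its own specialized launch path.
variants = list(product(ranks, affine, gating))
print(len(variants))  # 8 specializations, before fp32/fp16/bf16 dtype dispatch
```

Even a "simple" elementwise kernel can therefore accumulate substantial dispatch code once every supported shape/affine/gating combination gets its own instantiation.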
Ok, sounds good.
Would it be possible to discuss the details further? You're very welcome to join the SGLang Slack and DM me (Yihan Chen), or feel free to share another contact method if that's more convenient for you. @jianyingzhu
Force-pushed from ad462fe to 588252d.
@mickqian It's time to merge this optimization. https://github.com/sgl-project/sglang/actions/runs/21616947048/job/62427336509?pr=14717
Congrats, and a huge thanks to every participant for the fantastic collaboration on this one 🎉
Motivation
Profiling shows a lot of GPU bubbles, which can be mitigated via kernel fusion.
This PR adds two fused kernels:
- fused_layernorm_scale_shift
- fused_scale_residual_layernorm_scale_shift

The kernel fusion reduces kernel launch overhead for the Qwen-Image, WAN and HunyuanVideo pipelines.
The fused kernels support 2D ([batch_size * seq_len, inner_dim] or [1, inner_dim]), 3D ([1, 1, inner_dim]), and 4D ([batch_size, num_frames, 1, inner_dim]) modulation for diffusion models. Numerical parity is maintained with the existing PyTorch/Triton implementations across fp32, fp16, and bf16.
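As a reference for what fused_layernorm_scale_shift computes, here is a minimal pure-Python sketch of the per-row math. The function name, argument layout, and `eps` default are illustrative, not the CUDA kernel's actual signature:

```python
import math

def layernorm_scale_shift(x, scale, shift, eps=1e-6):
    # Per-row math fused by the kernel: y = LayerNorm(x) * (1 + scale) + shift.
    # x, scale, shift are one row of length inner_dim; in the real kernel,
    # scale/shift may be broadcast from e.g. a [1, inner_dim] modulation tensor.
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    inv_std = 1.0 / math.sqrt(var + eps)
    return [(v - mean) * inv_std * (1.0 + s) + t
            for v, s, t in zip(x, scale, shift)]

row = layernorm_scale_shift([1.0, 2.0, 3.0, 4.0], [0.0] * 4, [0.0] * 4)
print(abs(sum(row)) < 1e-6)  # True: with zero scale/shift, the row is normalized
```

Unfused, this is at least three kernels (normalize, multiply, add) plus intermediate tensors; the fused version does it in one pass over the row.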
Modifications
- In sgl_kernel, add the fused_scale_residual_layernorm_scale_shift and fused_layernorm_scale_shift CUDA kernels based on CUTLASS.
- In layers/layernorm.py, update the LayerNorm path to call the fused kernels when available.

Done with @AichenF and @yingluosanqian.
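For the gated-residual variant, a pure-Python sketch of the sequence the fused kernel replaces is below. The function name is illustrative, and returning the updated residual alongside the modulated output is an assumption about the kernel's contract, not a documented signature:

```python
import math

def scale_residual_layernorm_scale_shift(residual, x, gate, scale, shift, eps=1e-6):
    # Step 1: gated residual add.
    h = [r + g * v for r, g, v in zip(residual, gate, x)]
    # Steps 2-3: LayerNorm over the row, then scale/shift modulation.
    mean = sum(h) / len(h)
    var = sum((v - mean) ** 2 for v in h) / len(h)
    inv_std = 1.0 / math.sqrt(var + eps)
    y = [(v - mean) * inv_std * (1.0 + s) + t
         for v, s, t in zip(h, scale, shift)]
    return h, y  # updated residual and modulated output (assumed contract)

h, y = scale_residual_layernorm_scale_shift(
    [1.0] * 4, [1.0, 2.0, 3.0, 4.0], [1.0] * 4, [0.0] * 4, [0.0] * 4)
print(h)  # [2.0, 3.0, 4.0, 5.0]
```

Fusing these steps into one kernel removes the intermediate global-memory round trips between the residual add, the normalization, and the modulation, which is where the launch-overhead and bandwidth savings come from.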
Special thanks to @yingluosanqian, who helped a lot.
All unit tests pass.

Benchmarking and Profiling
Benchmark: fused_norm_scale_shift
Benchmark: fused_scale_residual_norm_scale_shift
Profiling
Qwen
Command:
1. High-level Summary
2. Stage Breakdown
e560ec78d6a0fe47e354c826d40aa0881178771d

Wan2.2
Command
1. High-level Summary
2. Stage Breakdown
e560ec78d6a0fe47e354c826d40aa0881178771d

Wan-AI/Wan2.1-T2V-1.3B-Diffusers
1. High-level Summary
2. Stage Breakdown
beff96c6874e80c47d31dfe6eb64fc4ef1b1ae89

Hunyuan
Command
1. High-level Summary
2. Stage Breakdown
767f55e14a732e1ab763323ff97ba38c7492c7ec

Checklist