
[diffusion] kernel fusion: gated residual layernorm scale shift and layernorm scale shift kernel fusion for Qwen-Image, WAN and HunyuanVideo#14717

Merged
mickqian merged 114 commits into sgl-project:main from AichenF:feat/layernorm_scale_shift_kernel
Feb 4, 2026

Conversation

@jianyingzhu
Contributor

@jianyingzhu jianyingzhu commented Dec 9, 2025

Motivation

Profiling shows many GPU bubbles (idle gaps between kernel launches), which can be mitigated via kernel fusion.

This PR:

  1. Fuses residual addition, gating, LayerNorm, and scale/shift into a single CUDA kernel, fused_scale_residual_layernorm_scale_shift.
  2. Fuses LayerNorm and scale/shift into a single CUDA kernel, fused_layernorm_scale_shift.

The kernel fusion reduces kernel launch overhead for Qwen-Image, WAN and HunyuanVideo pipelines.

The fused kernels support 2D ([batch_size * seq_len, inner_dim] or [1, inner_dim]), 3D ([1, 1, inner_dim]), and 4D ([batch_size, num_frames, 1, inner_dim]) modulation tensors for diffusion models. Numerical parity with the existing PyTorch/Triton implementations is maintained across fp32, fp16, and bf16.
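For reference, the semantics of the two fused ops can be sketched in NumPy. This is a hand-written stand-in, not the kernel code; the (1 + scale) modulation convention is an assumption based on the usual AdaLN-style DiT pattern, and the function names mirror the kernels for readability only:

```python
import numpy as np

def layernorm(x, eps=1e-6):
    # Plain LayerNorm over the last dimension, no affine weights.
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def fused_layernorm_scale_shift_ref(x, scale, shift, eps=1e-6):
    # LayerNorm followed by AdaLN-style modulation; scale/shift broadcast
    # against x, so 2D [1, N], 3D [1, 1, N], or 4D [B, F, 1, N] shapes work.
    return layernorm(x, eps) * (1.0 + scale) + shift

def fused_scale_residual_layernorm_scale_shift_ref(residual, x, gate, scale, shift, eps=1e-6):
    # Gated residual add, then LayerNorm + scale/shift; also returns the
    # updated residual, which the next block consumes.
    h = residual + gate * x
    return fused_layernorm_scale_shift_ref(h, scale, shift, eps), h
```

The CUDA kernels collapse these steps into one launch instead of the several separate element-wise and reduction kernels this decomposition would otherwise issue.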

Modifications

  • In sgl_kernel, add the fused_scale_residual_layernorm_scale_shift and fused_layernorm_scale_shift CUDA kernels based on CUTLASS.
  • In layers/layernorm.py, update the LayerNorm path to call the fused kernels when available.
  • In the DiT models, update the Qwen-Image, WAN, and HunyuanVideo implementations to use the fused kernels in layernorm.py.
  • Support both LayerNorm and RMSNorm.
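The LayerNorm/RMSNorm switch can be illustrated with a small NumPy reference (hypothetical helper name; the real kernels select the variant at launch time, and `weight` models the optional affine term shown in the benchmark's `affine` column):

```python
import numpy as np

def norm_scale_shift_ref(x, scale, shift, norm_type="layer", weight=None, eps=1e-6):
    # RMSNorm skips the mean subtraction and normalizes by the
    # root-mean-square; LayerNorm centers first.
    if norm_type == "rms":
        y = x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    else:
        mu = x.mean(axis=-1, keepdims=True)
        y = (x - mu) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)
    if weight is not None:
        y = y * weight
    return y * (1.0 + scale) + shift
```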

Co-authored with @AichenF and @yingluosanqian. Special thanks to @yingluosanqian for his extensive help.

All unit tests pass.

Benchmarking and Profiling

Benchmark: fused_norm_scale_shift

| B | S | D | norm_type | affine | SGLang Native | CuTeDSL | CUDA C |
|---|---|---|---|---|---|---|---|
| 1 | 128 | 1024 | layer | True | 31.680 | 7.616 | 7.584 |
| 1 | 128 | 1024 | layer | False | 22.208 | 7.536 | 7.584 |
| 1 | 128 | 1024 | rms | True | 12.128 | 7.328 | 7.296 |
| 1 | 128 | 1024 | rms | False | 12.192 | 7.104 | 7.520 |
| 1 | 128 | 3072 | layer | True | 36.736 | 8.352 | 8.640 |
| 1 | 128 | 3072 | layer | False | 25.216 | 8.352 | 8.864 |
| 1 | 128 | 3072 | rms | True | 12.928 | 8.288 | 8.320 |
| 1 | 128 | 3072 | rms | False | 12.960 | 8.096 | 8.352 |
| 1 | 128 | 4096 | layer | True | 38.496 | 8.896 | 9.056 |
| 1 | 128 | 4096 | layer | False | 26.912 | 8.640 | 9.248 |
| 1 | 128 | 4096 | rms | True | 13.728 | 8.640 | 8.960 |
| 1 | 128 | 4096 | rms | False | 13.760 | 8.672 | 8.928 |
| 1 | 1024 | 1024 | layer | True | 38.176 | 9.952 | 10.048 |
| 1 | 1024 | 1024 | layer | False | 27.200 | 9.888 | 9.792 |
| 1 | 1024 | 1024 | rms | True | 14.976 | 9.696 | 9.728 |
| 1 | 1024 | 1024 | rms | False | 15.008 | 9.504 | 9.472 |
| 1 | 1024 | 3072 | layer | True | 56.976 | 17.696 | 16.448 |
| 1 | 1024 | 3072 | layer | False | 46.048 | 16.256 | 16.736 |
| 1 | 1024 | 3072 | rms | True | 23.360 | 15.808 | 16.000 |
| 1 | 1024 | 3072 | rms | False | 23.424 | 16.000 | 16.032 |
| 1 | 1024 | 4096 | layer | True | 69.024 | 20.000 | 18.784 |
| 1 | 1024 | 4096 | layer | False | 56.288 | 18.816 | 18.752 |
| 1 | 1024 | 4096 | rms | True | 27.136 | 18.656 | 18.464 |
| 1 | 1024 | 4096 | rms | False | 27.200 | 18.656 | 18.688 |
| 1 | 4096 | 1024 | layer | True | 69.792 | 19.072 | 18.688 |
| 1 | 4096 | 1024 | layer | False | 58.304 | 18.464 | 18.656 |
| 1 | 4096 | 1024 | rms | True | 26.496 | 18.432 | 18.624 |
| 1 | 4096 | 1024 | rms | False | 26.496 | 18.624 | 18.432 |
| 1 | 4096 | 3072 | layer | True | 172.128 | 48.448 | 43.936 |
| 1 | 4096 | 3072 | layer | False | 157.216 | 42.144 | 43.744 |
| 1 | 4096 | 3072 | rms | True | 62.080 | 43.168 | 41.376 |
| 1 | 4096 | 3072 | rms | False | 62.048 | 43.168 | 41.600 |
| 1 | 4096 | 4096 | layer | True | 225.408 | 56.208 | 52.544 |
| 1 | 4096 | 4096 | layer | False | 212.576 | 52.576 | 52.384 |
| 1 | 4096 | 4096 | rms | True | 80.672 | 52.608 | 51.968 |
| 1 | 4096 | 4096 | rms | False | 80.864 | 52.512 | 52.224 |
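The numbers above come from the sgl-kernel benchmark scripts. A minimal harness of the same shape (a CPU wall-clock stand-in; the actual GPU benchmarks time with CUDA events and a synchronize, but the warmup-then-median structure is the same) looks like:

```python
import time
import numpy as np

def bench_us(fn, *args, warmup=3, iters=50):
    # Warm up to exclude one-time setup, then report the median
    # wall-clock latency of fn(*args) in microseconds.
    for _ in range(warmup):
        fn(*args)
    samples = []
    for _ in range(iters):
        t0 = time.perf_counter()
        fn(*args)
        samples.append((time.perf_counter() - t0) * 1e6)
    return float(np.median(samples))
```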

Benchmark: fused_scale_residual_norm_scale_shift

| B | S | D | norm_type | affine | SGLang Native | CuTeDSL | CUDA C |
|---|---|---|---|---|---|---|---|
| 1 | 128 | 1024 | layer | True | 40.672 | 7.968 | 8.256 |
| 1 | 128 | 1024 | layer | False | 30.848 | 7.872 | 8.320 |
| 1 | 128 | 1024 | rms | True | 20.288 | 7.936 | 7.840 |
| 1 | 128 | 1024 | rms | False | 20.256 | 7.936 | 7.744 |
| 1 | 128 | 3072 | layer | True | 45.744 | 9.280 | 9.312 |
| 1 | 128 | 3072 | layer | False | 34.144 | 8.992 | 9.472 |
| 1 | 128 | 3072 | rms | True | 21.280 | 8.896 | 9.088 |
| 1 | 128 | 3072 | rms | False | 21.376 | 8.736 | 9.088 |
| 1 | 128 | 4096 | layer | True | 47.680 | 9.888 | 10.080 |
| 1 | 128 | 4096 | layer | False | 35.792 | 9.856 | 10.048 |
| 1 | 128 | 4096 | rms | True | 21.856 | 9.536 | 9.504 |
| 1 | 128 | 4096 | rms | False | 21.952 | 9.536 | 9.696 |
| 1 | 1024 | 1024 | layer | True | 50.528 | 11.552 | 11.616 |
| 1 | 1024 | 1024 | layer | False | 38.624 | 11.296 | 11.488 |
| 1 | 1024 | 1024 | rms | True | 24.768 | 11.360 | 11.488 |
| 1 | 1024 | 1024 | rms | False | 24.864 | 11.616 | 11.264 |
| 1 | 1024 | 3072 | layer | True | 78.048 | 21.248 | 20.256 |
| 1 | 1024 | 3072 | layer | False | 64.896 | 20.416 | 20.384 |
| 1 | 1024 | 3072 | rms | True | 40.288 | 20.192 | 20.384 |
| 1 | 1024 | 3072 | rms | False | 40.160 | 20.192 | 20.160 |
| 1 | 1024 | 4096 | layer | True | 90.976 | 25.408 | 24.832 |
| 1 | 1024 | 4096 | layer | False | 78.432 | 24.864 | 24.832 |
| 1 | 1024 | 4096 | rms | True | 48.000 | 24.960 | 24.768 |
| 1 | 1024 | 4096 | rms | False | 48.000 | 24.928 | 24.784 |
| 1 | 4096 | 1024 | layer | True | 93.584 | 25.472 | 25.088 |
| 1 | 4096 | 1024 | layer | False | 80.672 | 24.896 | 24.896 |
| 1 | 4096 | 1024 | rms | True | 48.560 | 24.864 | 24.800 |
| 1 | 4096 | 1024 | rms | False | 48.416 | 24.928 | 24.832 |
| 1 | 4096 | 3072 | layer | True | 238.976 | 59.424 | 57.984 |
| 1 | 4096 | 3072 | layer | False | 224.352 | 57.952 | 57.952 |
| 1 | 4096 | 3072 | rms | True | 126.272 | 57.920 | 57.888 |
| 1 | 4096 | 3072 | rms | False | 124.992 | 57.664 | 57.920 |
| 1 | 4096 | 4096 | layer | True | 314.336 | 74.048 | 74.208 |
| 1 | 4096 | 4096 | layer | False | 300.496 | 74.144 | 73.920 |
| 1 | 4096 | 4096 | rms | True | 166.208 | 73.920 | 74.048 |
| 1 | 4096 | 4096 | rms | False | 166.640 | 74.176 | 73.760 |

Profiling

Qwen

Command:

```
env SGLANG_CACHE_DIT_ENABLED=true SGLANG_CACHE_DIT_FN=1 SGLANG_CACHE_DIT_BN=0 SGLANG_CACHE_DIT_WARMUP=4 SGLANG_CACHE_DIT_RDT=0.24 SGLANG_CACHE_DIT_MC=3 SGLANG_CACHE_DIT_TAYLORSEER=false SGLANG_CACHE_DIT_TS_ORDER=1 SGLANG_CACHE_DIT_SCM_PRESET=none SGLANG_CACHE_DIT_SCM_POLICY=dynamic sglang generate --model-path=Qwen/Qwen-Image-2512 --prompt="A futuristic cyberpunk city at night, neon lights reflecting on wet streets, highly detailed, 8k" '--negative-prompt= ' --width=1024 --height=1024 --num-inference-steps=50 --guidance-scale=4.0 --seed=42 --save-output --enable-torch-compile --warmup --dit-cpu-offload false --text-encoder-cpu-offload false
```

1. High-level Summary

| Metric | Baseline | New | Diff |
|---|---|---|---|
| E2E Latency | 7072.14 ms | 6289.26 ms | -782.89 ms (-11.1%) |
| Throughput | 0.14 req/s | 0.16 req/s | - |

2. Stage Breakdown

| Stage Name | Baseline (ms) | New (ms) | Diff (ms) | Diff (%) |
|---|---|---|---|---|
| InputValidationStage | 0.05 | 0.06 | +0.01 | +19.7% |
| TextEncodingStage | 48.04 | 47.10 | -0.94 | -2.0% |
| ConditioningStage | 0.01 | 0.02 | +0.00 | +12.5% |
| TimestepPreparationStage | 0.66 | 0.81 | +0.15 | +22.7% |
| LatentPreparationStage | 0.18 | 0.22 | +0.04 | +19.1% |
| DenoisingStage | 7008.61 | 6223.78 | -784.82 | -11.2% |
| DecodingStage | 12.41 | 15.01 | +2.60 | +21.0% |
  • Baseline Commit: e560ec78d6a0fe47e354c826d40aa0881178771d
  • New Commit: e560ec78d6a0fe47e354c826d40aa0881178771d

Wan2.2

Command

```
sglang generate --model-path=Wan-AI/Wan2.2-T2V-A14B-Diffusers --log-level=info --prompt="A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." --negative-prompt=" " --720p --num-inference-steps=40 --num-frames=81 --guidance-scale=5.0 --seed=42 --save-output --num-gpus=4 --enable-cfg-parallel --ulysses-degree=2 --dit-layerwise-offload true --dit-cpu-offload false --vae-cpu-offload false --text-encoder-cpu-offload true --warmup --enable-torch-compile true
```

1. High-level Summary

| Metric | Baseline | New | Diff |
|---|---|---|---|
| E2E Latency | 362312.77 ms | 357770.03 ms | -4542.74 ms (-1.3%) |
| Throughput | 0.00 req/s | 0.00 req/s | - |

2. Stage Breakdown

| Stage Name | Baseline (ms) | New (ms) | Diff (ms) | Diff (%) |
|---|---|---|---|---|
| InputValidationStage | 0.06 | 0.06 | +0.00 | +2.3% |
| TextEncodingStage | 1165.91 | 1219.68 | +53.77 | +4.6% |
| ConditioningStage | 0.02 | 0.02 | -0.00 | -4.4% |
| TimestepPreparationStage | 0.62 | 0.57 | -0.05 | -8.7% |
| LatentPreparationStage | 0.20 | 0.20 | +0.00 | +0.1% |
| DenoisingStage | 352230.54 | 343721.09 | -8509.45 | -2.4% |
| DecodingStage | 8172.35 | 12084.46 | +3912.11 | +47.9% |
  • Baseline Commit: e560ec78d6a0fe47e354c826d40aa0881178771d
  • New Commit: e560ec78d6a0fe47e354c826d40aa0881178771d

Wan-AI/Wan2.1-T2V-1.3B-Diffusers

1. High-level Summary

| Metric | Baseline | New | Diff |
|---|---|---|---|
| E2E Latency | 71606.89 ms | 61612.73 ms | -9994.15 ms (-14.0%) |
| Throughput | 0.01 req/s | 0.02 req/s | - |

2. Stage Breakdown

| Stage Name | Baseline (ms) | New (ms) | Diff (ms) | Diff (%) |
|---|---|---|---|---|
| InputValidationStage | 0.06 | 0.03 | -0.03 | -44.8% |
| TextEncodingStage | 1164.18 | 1150.55 | -13.63 | -1.2% |
| ConditioningStage | 0.02 | 0.01 | -0.01 | -48.9% |
| TimestepPreparationStage | 0.32 | 0.23 | -0.09 | -27.9% |
| LatentPreparationStage | 0.12 | 0.08 | -0.04 | -33.7% |
| DenoisingStage | 67506.99 | 57052.85 | -10454.15 | -15.5% |
| DecodingStage | 2933.01 | 3407.29 | +474.28 | +16.2% |
  • Baseline Commit: beff96c6874e80c47d31dfe6eb64fc4ef1b1ae89
  • New Commit: beff96c6874e80c47d31dfe6eb64fc4ef1b1ae89

Hunyuan

Command

```
env SGLANG_CACHE_DIT_ENABLED=true SGLANG_CACHE_DIT_FN=1 SGLANG_CACHE_DIT_BN=0 SGLANG_CACHE_DIT_WARMUP=4 SGLANG_CACHE_DIT_RDT=0.24 SGLANG_CACHE_DIT_MC=3 SGLANG_CACHE_DIT_TAYLORSEER=false SGLANG_CACHE_DIT_TS_ORDER=1 SGLANG_CACHE_DIT_SCM_PRESET=none SGLANG_CACHE_DIT_SCM_POLICY=dynamic sglang generate --model-path hunyuanvideo-community/HunyuanVideo --text-encoder-cpu-offload --pin-cpu-memory --prompt "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." --save-output --num-frames 65 --width 848 --height 480 --num-inference-steps 30 --seed=42 --save-output --warmup --enable-torch-compile true
```

1. High-level Summary

| Metric | Baseline | New | Diff |
|---|---|---|---|
| E2E Latency | 71280.41 ms | 70908.72 ms | -371.69 ms (-0.5%) |
| Throughput | 0.01 req/s | 0.01 req/s | - |

2. Stage Breakdown

| Stage Name | Baseline (ms) | New (ms) | Diff (ms) | Diff (%) |
|---|---|---|---|---|
| InputValidationStage | 0.05 | 0.07 | +0.01 | +26.7% |
| TextEncodingStage | 310.40 | 309.80 | -0.60 | -0.2% |
| ConditioningStage | 0.02 | 0.02 | -0.00 | -3.9% |
| TimestepPreparationStage | 0.48 | 0.48 | -0.00 | -1.0% |
| LatentPreparationStage | 0.17 | 0.15 | -0.01 | -6.4% |
| DenoisingStage | 53813.70 | 53277.02 | -536.68 | -1.0% |
| DecodingStage | 17153.22 | 17318.78 | +165.56 | +1.0% |
  • Baseline Commit: 767f55e14a732e1ab763323ff97ba38c7492c7ec
  • New Commit: 767f55e14a732e1ab763323ff97ba38c7492c7ec

Checklist

@gemini-code-assist
Contributor

Summary of Changes

Hello @jianyingzhu, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the performance of the Qwen-Image and WAN diffusion models by implementing kernel fusion for critical operations. By combining residual connections, gating, LayerNormalization, and scale/shift into highly optimized CUDA kernels, the changes aim to reduce GPU overhead and improve overall efficiency, while maintaining numerical precision across various floating-point formats.

Highlights

  • New Fused CUDA Kernels: Introduced fused_scale_residual_layernorm_scale_shift and fused_layernorm_scale_shift CUDA kernels to combine multiple operations into single, optimized GPU kernels.
  • Performance Optimization: These kernels are designed to mitigate GPU bubbles, reduce memory pressure, and lower kernel launch overhead in Qwen-Image and WAN pipelines.
  • Flexible Modulation Support: The fused kernels support both per-token 2D modulation ([M,N]) and per-(batch, frame) 4D modulation ([B,F,1,N]) for video diffusion models.
  • Optional Gating Mechanisms: Provides optional gating with various memory layouts to cover common patterns in diffusion backbones and cross-attention blocks.
  • Numerical Parity: Ensures numerical accuracy is maintained with existing PyTorch/Triton implementations across fp32, fp16, and bf16 data types.
  • Integration into Models: Updated the LayerNorm path in layernorm.py and specifically the Qwen-Image and WAN model implementations to utilize these new fused kernels.
  • Robustness for Text Encoding: Added a fallback mechanism in the text encoding stage to handle native Hugging Face encoder-decoder models that might not conform to the custom encoder path.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces new fused CUDA kernels for Layer Normalization (LN) combined with scale and shift operations, as well as a fused residual connection, gating, LN, and scale/shift. The changes add new C++ CUDA kernel files (fused_layernorm_scale_shift.cu), register the kernels in common_extension.cc and sgl_kernel_ops.h, and expose them via Python wrappers in sgl_kernel/elementwise.py. layernorm.py is updated to leverage the fused kernels within its LayerNormScaleShift and ScaleResidualLayerNormScaleShift classes, including logic to handle various tensor dimensions for scale, shift, and gate parameters, and a fallback to non-fused operations. qwen_image.py is modified to use the new LayerNormScaleShift class, replacing existing LayerNorm instances and adapting the forward pass to explicitly pass scale and shift parameters. Additionally, text_encoding.py gains a try-except block to gracefully handle Hugging Face encoder-decoder models.

Review comments highlighted several issues:

  • an incorrect argument order and missing parameters in a call to fused_scale_residual_layernorm_scale_shift in layernorm.py
  • overly complex and potentially incorrect dimension handling for scale_blc and shift_blc in both LayerNormScaleShift and ScaleResidualLayerNormScaleShift
  • a mismatch in 3D gate tensor handling between Python and C++ kernel expectations
  • redundant if blocks for 4D scale/shift in LayerNormScaleShift
  • unused debug lines
  • a bfloat16 accuracy test tolerance that was too high
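The fallback to non-fused operations that the review mentions follows a common dispatch pattern. A NumPy sketch with hypothetical names (the actual wrapper lives in layernorm.py and dispatches on device, dtype, and tensor shapes):

```python
import numpy as np

def apply_norm_scale_shift(x, scale, shift, fused_op=None):
    # Hypothetical sketch of the fallback pattern: use the fused kernel
    # when one is registered for this device/dtype, otherwise fall back
    # to the separate-op eager path.
    if fused_op is not None:
        return fused_op(x, scale, shift)
    mu = x.mean(axis=-1, keepdims=True)
    y = (x - mu) / np.sqrt(x.var(axis=-1, keepdims=True) + 1e-6)
    return y * (1.0 + scale) + shift
```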

@jianyingzhu jianyingzhu changed the title [diffusion] kernel fusion: fused_scale_residual_layernorm_scale_shift and fused_layernorm_scale_shift [diffusion] kernel fusion: add layernorm scale shift fusion for Wan and Qwen-Image models Dec 9, 2025
@jianyingzhu jianyingzhu changed the title [diffusion] kernel fusion: add layernorm scale shift fusion for Wan and Qwen-Image models [diffusion] kernel fusion: add gated residual layernorm scale shift and layernorm scale shift fusion for Qwen-Image, WAN and HunyuanVideo Dec 10, 2025
@jianyingzhu jianyingzhu changed the title [diffusion] kernel fusion: add gated residual layernorm scale shift and layernorm scale shift fusion for Qwen-Image, WAN and HunyuanVideo [diffusion] kernel fusion: gated residual layernorm scale shift and layernorm scale shift kernel fusion for Qwen-Image, WAN and HunyuanVideo Dec 10, 2025
@yingluosanqian
Collaborator

Hi, I noticed today that this PR and my PR do similar things.

You additionally implemented the fusion for layernorm_scale_shift, while I added support for RMS Norm in my PR. There are also some minor differences — for example, I used the Welford algorithm, and you used shared memory (smem) to temporarily store x.

I think our PRs could be merged?

@FlamingoPg
Collaborator

Need add this kernel test/benchmark in sgl-kernel

@jianyingzhu
Contributor Author

> Hi, I noticed today that this PR and my PR do similar things.
>
> You additionally implemented the fusion for layernorm_scale_shift, while I added support for RMS Norm in my PR. There are also some minor differences — for example, I used the Welford algorithm, and you used shared memory (smem) to temporarily store x.
>
> I think our PRs could be merged?

There is indeed some overlap. Our CUDA kernel code is based on CUTLASS's LayerNorm. We can discuss it further.

@jianyingzhu
Contributor Author

Need add this kernel test/benchmark in sgl-kernel

Thank you, the benchmark has been added and the results are shown in the PR.

@BBuf
Collaborator

BBuf commented Dec 11, 2025

Some high-level suggestions: The operation of this kernel isn’t particularly complex. I noticed the initial commit seemed to have a CUDA DSL implementation—why was the final version switched to raw .cu? Also, the code length for such a simple kernel seems excessive. Do we really need this much dispatch logic? Perhaps we could start by writing a more generalized kernel first. Also, could you share the micro-benchmark results for this kernel?

@AichenF
Contributor

AichenF commented Dec 12, 2025

> Some high-level suggestions: The operation of this kernel isn’t particularly complex. I noticed the initial commit seemed to have a CUDA DSL implementation—why was the final version switched to raw .cu? Also, the code length for such a simple kernel seems excessive. Do we really need this much dispatch logic? Perhaps we could start by writing a more generalized kernel first. Also, could you share the micro-benchmark results for this kernel?

Well, we found that sgl-kernel currently doesn’t have any kernels written with CuTeDSL, so we chose a style that stays more consistent with the existing code. The complicated kernel launch logic mainly comes from different models having different input format requirements, for example: 2D/3D inputs, w/wo affine, and w/wo gate residual.

@BBuf
Collaborator

BBuf commented Dec 12, 2025

> Some high-level suggestions: The operation of this kernel isn’t particularly complex. I noticed the initial commit seemed to have a CUDA DSL implementation—why was the final version switched to raw .cu? Also, the code length for such a simple kernel seems excessive. Do we really need this much dispatch logic? Perhaps we could start by writing a more generalized kernel first. Also, could you share the micro-benchmark results for this kernel?
>
> Well, we found that sgl-kernel currently doesn’t have any kernels written with CuTeDSL, so we chose a style that stays more consistent with the existing code. The complicated kernel launch logic mainly comes from different models having different input format requirements, for example: 2D/3D inputs, w/wo affine, and w/wo gate residual.

Ok, sounds good.

@yingluosanqian
Collaborator

> Hi, I noticed today that this PR and my PR do similar things. You additionally implemented the fusion for layernorm_scale_shift, while I added support for RMS Norm in my PR. There are also some minor differences — for example, I used the Welford algorithm, and you used shared memory (smem) to temporarily store x. I think our PRs could be merged?
>
> There is indeed some overlap. Our CUDA kernel code is based on CUTLASS’s LayerNorm. We can discuss it further.

Would it be possible to discuss the details further? You’re very welcome to join the SGLang Slack and DM me (Yihan Chen), or feel free to share another contact method if that’s more convenient for you. @jianyingzhu

@github-actions github-actions bot added documentation Improvements or additions to documentation quant LLM Quantization amd labels Feb 3, 2026
@yingluosanqian yingluosanqian force-pushed the feat/layernorm_scale_shift_kernel branch from ad462fe to 588252d Compare February 3, 2026 04:16
@yingluosanqian yingluosanqian removed documentation Improvements or additions to documentation quant LLM Quantization amd labels Feb 3, 2026
@BBuf
Collaborator

BBuf commented Feb 4, 2026

@mickqian It's time to merge this optimization. https://github.com/sgl-project/sglang/actions/runs/21616947048/job/62427336509?pr=14717

@mickqian
Collaborator

mickqian commented Feb 4, 2026

Congrats and huge thanks to every participant for the fantastic collaboration on this one 🎉

10 participants