
[diffusion] kernel fusion: ScaleResidualNormScaleShift #14289

Closed
yingluosanqian wants to merge 9 commits into sgl-project:main from yingluosanqian:fused_kernel_scale_residual_norm_scale_shift

Conversation

@yingluosanqian
Collaborator

@yingluosanqian yingluosanqian commented Dec 2, 2025

Related to #12799 (Roadmap)
Related to #14437 (applies the fused kernel to diffusion models; this PR only implements the kernel in the sgl-kernel repo)

Motivation

This PR fuses the residual-add, normalization, and scale-shift steps in diffusion models, eliminating the redundant global-memory round trips between the three steps:

  1. residual_out = residual + gate * x
  2. normalized = norm(residual_out)
  3. modulated = (1 + scale) * normalized + shift
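As a reference, the three steps can be sketched in plain Python (a minimal single-row sketch assuming LayerNorm with `weight`/`bias`; the scalar `gate` broadcast and the `eps` parameter are illustrative assumptions, not the actual kernel API):

```python
import math

def scale_residual_norm_scale_shift(x, residual, gate, scale, shift,
                                    weight, bias, eps=1e-6):
    """Pure-Python reference for the three fused steps (LayerNorm assumed).

    All list arguments have length hidden_dim; `gate` is a scalar here for
    simplicity (the kernel broadcasts it per batch element).
    """
    # Step 1: residual_out = residual + gate * x
    residual_out = [r + gate * xi for r, xi in zip(residual, x)]
    # Step 2: normalized = LayerNorm(residual_out)
    d = len(residual_out)
    mean = sum(residual_out) / d
    var = sum((v - mean) ** 2 for v in residual_out) / d
    inv_std = 1.0 / math.sqrt(var + eps)
    normalized = [(v - mean) * inv_std * w + b
                  for v, w, b in zip(residual_out, weight, bias)]
    # Step 3: modulated = (1 + scale) * normalized + shift
    modulated = [(1.0 + s) * n + sh
                 for s, n, sh in zip(scale, normalized, shift)]
    return modulated, residual_out
```

Both `modulated` and `residual_out` are kernel outputs, since `residual_out` feeds the next block's residual path.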

Modifications

  1. Add a new fused kernel ScaleResidualNormScaleShift to sgl-kernel
  2. Integrate the kernel into models such as Hunyuan and Wan

Accuracy Tests

Use the native Python implementation as the baseline:

  • For FP32, we require 1e-5 absolute and 1e-5 relative tolerance.
  • For FP16, we use 1e-2 absolute and 1e-2 relative tolerance.

(These tolerances may be further refined based on broader model-level evaluation.)
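A sketch of how such a combined absolute/relative tolerance comparison typically works (the `within_tol` helper is hypothetical, not the actual test code in the PR):

```python
def within_tol(ref, out, atol, rtol):
    """Elementwise check |out - ref| <= atol + rtol * |ref|, the usual
    combined absolute/relative tolerance rule (as in allclose-style APIs)."""
    return all(abs(o - r) <= atol + rtol * abs(r)
               for r, o in zip(ref, out))
```

For example, a 0.5% deviation passes the FP16 tolerances (1e-2 / 1e-2) but fails the FP32 ones (1e-5 / 1e-5).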

Benchmarking and Profiling

Single Kernel Test: 4.99x speedup

Benchmarking on H200 across multiple (seq_len, hidden_dim) settings shows that the fused kernel achieves an average 4.99× speedup over the unfused baseline (details here), with smaller shapes yielding larger gains.

E2E Test

Performance Comparison Report: Wan-AI/Wan2.1-T2V-1.3B-Diffusers (1.07x speedup)

1. High-level Summary
| Metric | Baseline | New | Diff | Status |
| --- | --- | --- | --- | --- |
| E2E Latency | 87712.82 ms | 81929.56 ms | -5783.26 ms (-6.6%) | |
| Throughput | 0.01 req/s | 0.01 req/s | - | - |
2. Stage Breakdown
| Stage Name | Baseline (ms) | New (ms) | Diff (ms) | Diff (%) | Status |
| --- | --- | --- | --- | --- | --- |
| InputValidationStage | 0.04 | 0.04 | +0.00 | +0.5% | ⚪️ |
| TextEncodingStage | 2649.54 | 2546.71 | -102.83 | -3.9% | ⚪️ |
| ConditioningStage | 0.05 | 0.02 | -0.03 | -60.0% | ⚪️ |
| TimestepPreparationStage | 3.02 | 2.51 | -0.51 | -16.9% | ⚪️ |
| LatentPreparationStage | 16.25 | 12.18 | -4.07 | -25.1% | ⚪️ |
| DenoisingStage | 81457.85 | 75774.33 | -5683.52 | -7.0% | ⚪️ |
| DecodingStage | 3568.70 | 3584.35 | +15.65 | +0.4% | ⚪️ |
Metadata
  • Baseline Commit: e6ced88040d0d35622e3ee46166cd2d41bd57091
  • New Commit: e6ced88040d0d35622e3ee46166cd2d41bd57091
  • Timestamp: 2025-12-04T14:12:54.485444

Performance Comparison Report: hunyuanvideo-community/HunyuanVideo

Run with gpu-nums=4.

No clear speedup is seen on Hunyuan, likely because its much longer sequence length (115,200 in Hunyuan vs. 32,760 in Wan) makes the attention kernels the dominant cost.

1. High-level Summary
| Metric | Baseline | New | Diff | Status |
| --- | --- | --- | --- | --- |
| E2E Latency | 395672.21 ms | 394468.82 ms | -1203.40 ms (-0.3%) | |
| Throughput | 0.00 req/s | 0.00 req/s | - | - |
2. Stage Breakdown
| Stage Name | Baseline (ms) | New (ms) | Diff (ms) | Diff (%) | Status |
| --- | --- | --- | --- | --- | --- |
| InputValidationStage | 0.05 | 0.05 | +0.01 | +18.8% | ⚪️ |
| TextEncodingStage | 1254.86 | 1416.92 | +162.06 | +12.9% | 🔴 |
| ConditioningStage | 0.02 | 0.02 | +0.00 | +9.5% | ⚪️ |
| TimestepPreparationStage | 57.30 | 58.78 | +1.48 | +2.6% | ⚪️ |
| LatentPreparationStage | 324.41 | 323.24 | -1.18 | -0.4% | ⚪️ |
| DenoisingStage | 376379.62 | 375239.32 | -1140.30 | -0.3% | ⚪️ |
| DecodingStage | 17619.50 | 17395.27 | -224.22 | -1.3% | ⚪️ |
Metadata
  • Baseline Commit: d276cc35c6e5c5ca9453c7a6e369a383cf206263
  • New Commit: d276cc35c6e5c5ca9453c7a6e369a383cf206263
  • Timestamp: 2025-12-04T15:27:43.426476

Profiling: 78% SOL bandwidth

Using a real case from the wan model (batch=1, seq=32760, hidden_dim=1536), the fused kernel processes approximately:

  • ~400 MB of reads
    • [1 × 32760 × 1536 × 2 (residual, x) + 1 × 1 × 1536 (gate) + 1536 × 2 (scale, shift) + 1536 × 2 (weight, bias)] × 4 B
  • ~400 MB of writes
    • [1 × 32760 × 1536 × 2 (modulated, residual_out)] × 4 B

The kernel completes in 213 µs, yielding an effective bandwidth of 800 MB / 213 µs ≈ 3.76 TB/s. This reaches ~78% of the theoretical H200 HBM bandwidth (4.8 TB/s), consistent with the ncu report.
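The arithmetic above can be reproduced with a small script (tensor shapes and the 213 µs timing are taken from the text; the helper name and default arguments are ours):

```python
def effective_bandwidth_tb_s(batch=1, seq=32760, hidden=1536,
                             elem_bytes=4, kernel_time_us=213.0):
    """Recompute the memory-traffic estimate from the profiling section.

    Reads: residual + x (two [batch, seq, hidden] tensors) plus the small
    gate / scale / shift / weight / bias vectors.
    Writes: modulated + residual_out (two [batch, seq, hidden] tensors).
    """
    big = batch * seq * hidden * elem_bytes          # one large tensor
    small = (batch * hidden + 2 * hidden + 2 * hidden) * elem_bytes
    reads = 2 * big + small
    writes = 2 * big
    total_bytes = reads + writes                     # ~805 MB
    return total_bytes / (kernel_time_us * 1e-6) / 1e12  # TB/s
```

With the default FP32 shapes this lands at roughly 3.78 TB/s, matching the ~78%-of-SOL figure above (the text rounds the traffic to 800 MB).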

Checklist

  • cuda kernel impl
  • unit test
    • acc test
    • perf test
  • e2e test
  • clean code

Others

Kernel Impl

  1. The kernel follows the 3 steps below:
     • elementwise: residual_out = residual + gate * x;
     • reduce: normalized = norm(residual_out);
     • elementwise: modulated = (1 + scale) * normalized + shift.
  2. Elementwise ops use ldg.128 (ldg.64 for fp16) to load/store 4 elements at a time.
  3. The reduction uses the Welford algorithm for numerical stability:
     • hidden_dim ≤ 1024: one warp per hidden_dim (thread → warp reduce).
     • hidden_dim > 1024: one CTA per hidden_dim (thread → warp → CTA reduce).
  4. residual_out is not stored in shared memory, relying on the L1 cache instead.
  5. The kernel covers all data types and shapes supported by the original unfused implementation.
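For reference, the scalar form of the Welford reduction (the CUDA kernel merges per-thread partials at warp/CTA level; this single-pass Python sketch shows only the update rule, not the parallel merge):

```python
def welford(values):
    """Single-pass Welford mean/variance: numerically stable because it
    accumulates deviations from a running mean instead of sum-of-squares."""
    mean, m2, count = 0.0, 0.0, 0
    for v in values:
        count += 1
        delta = v - mean
        mean += delta / count          # update running mean
        m2 += delta * (v - mean)       # accumulate second moment
    var = m2 / count if count else float("nan")  # population variance
    return mean, var
```

The stability matters at fp16/bf16 scales, where the naive `E[x^2] - E[x]^2` formulation can cancel catastrophically.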

@mickqian
Collaborator

mickqian commented Dec 2, 2025

https://github.com/sgl-project/sglang/blob/main/python/sglang/multimodal_gen/docs/contributing.md

@yingluosanqian force-pushed the fused_kernel_scale_residual_norm_scale_shift branch 2 times, most recently from f9a749a to 0ddc80b, on December 4, 2025 15:58
@yingluosanqian marked this pull request as draft on December 4, 2025 16:36
@yingluosanqian changed the title from [WIP] [Diffusion] Kernel Fusion: ScaleResidualNormScaleShift to [WIP] [diffusion] kernel fusion: ScaleResidualNormScaleShift on Dec 4, 2025
@yingluosanqian force-pushed the branch from 0ddc80b to 776dffc on December 4, 2025 16:57
@yingluosanqian marked this pull request as ready for review on December 4, 2025 17:39
@yingluosanqian changed the title from [WIP] [diffusion] kernel fusion: ScaleResidualNormScaleShift to [diffusion] kernel fusion: ScaleResidualNormScaleShift on Dec 4, 2025
@BBuf
Collaborator

BBuf commented Dec 5, 2025

For kernels related to sglang diffusion, you can create a new folder named sgl_diffusion under csrc, and then move the kernel folder from here into it.

@BBuf
Collaborator

BBuf commented Dec 5, 2025

Please write separate test and benchmark cases for this kernel, referring to other kernels under sgl-kernel.

@yingluosanqian
Collaborator Author

For kernels related to sglang diffusion, you can create a new folder named sgl_diffusion under csrc, and then move the kernel folder from here into it.

OK.

Please write separate test and benchmark cases for this kernel, referring to other kernels under sgl-kernel.

added: sgl-kernel/tests/sgl_diffusion/test_fused_scale_residual_norm_scale_shift.py

Due to the large number of possible combinations, I did not generate all of them but ensured each parameter appears at least once.

@yingluosanqian force-pushed the fused_kernel_scale_residual_norm_scale_shift branch from 2f382a7 to e641757 on December 8, 2025 07:13
  template <typename PtrTy, typename RegT>
  __inline__ __device__ void load4_cast(const PtrTy* ptr, RegT v[4]) {
    using Raw = std::conditional_t<std::is_same_v<PtrTy, float>, float4, ushort4>;
    Raw raw = *reinterpret_cast<const Raw*>(ptr);
Collaborator

Should we add a pointer align check? Such as: assert(reinterpret_cast<uintptr_t>(ptr) % 16 == 0);

Collaborator Author

Thanks. I fixed the typo, improved the alignment checks, and added tests for unaligned addresses.

Previously, I only used D % 4 == 0 to decide whether vectorized loads were allowed. Now I use a stricter check based on the actual data pointer alignment:
reinterpret_cast<uintptr_t>(t.data_ptr()) % 16 == 0 for FP32, and
reinterpret_cast<uintptr_t>(t.data_ptr()) % 8 == 0 for FP16/BF16.
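That dispatch rule could be modeled as follows (a hypothetical Python sketch; `can_vectorize` and its parameters are illustrative, mirroring the check described above rather than the actual C++ code):

```python
def can_vectorize(data_ptr, hidden_dim, dtype_bytes):
    """Decide whether 4-element vectorized loads (ldg.128 / ldg.64) are safe:
    hidden_dim must be divisible by 4 AND the data pointer must be aligned
    to the vector width (16 B for 4-byte FP32, 8 B for 2-byte FP16/BF16)."""
    vec_bytes = 4 * dtype_bytes
    return hidden_dim % 4 == 0 and data_ptr % vec_bytes == 0
```

The earlier `D % 4 == 0` check alone would wrongly allow vectorized loads on tensors whose base address is misaligned (e.g. views with a non-zero storage offset), which is what the unaligned-address tests now cover.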

@yingluosanqian force-pushed the fused_kernel_scale_residual_norm_scale_shift branch from 31f991d to 66c962d on December 9, 2025 11:26
@yingluosanqian force-pushed the branch from 66c962d to a6fd54b on December 9, 2025 13:01