[diffusion] kernel fusion: ScaleResidualNormScaleShift #14289
yingluosanqian wants to merge 9 commits into sgl-project:main
Conversation
For kernels related to sglang diffusion, you can create a new folder named sgl_diffusion under csrc, and then move the kernel folder from here into it.
Please write separate test and benchmark cases for this kernel, referring to other kernels under

ok.
Added sgl-kernel/tests/sgl_diffusion/test_fused_scale_residual_norm_scale_shift.py. Due to the large number of possible parameter combinations, I did not generate all of them, but I ensured each parameter value appears at least once.
sgl-kernel/csrc/sgl_diffusion/scale_residual_norm_scale_shift/kernel_welford.cuh
template <typename PtrTy, typename RegT>
__inline__ __device__ void load4_cast(const PtrTy* ptr, RegT v[4]) {
  using Raw = std::conditional_t<std::is_same_v<PtrTy, float>, float4, ushort4>;
  Raw raw = *reinterpret_cast<const Raw*>(ptr);
Should we add a pointer alignment check? Such as: assert(reinterpret_cast<uintptr_t>(ptr) % 16 == 0);
Thanks. I fixed the typo, improved the alignment checks, and added tests for unaligned addresses.
Previously, I only used D % 4 == 0 to decide whether vectorized loads were allowed. Now I use a stricter check based on the actual data pointer alignment:
reinterpret_cast<uintptr_t>(t.data_ptr()) % 16 == 0 for FP32, and
reinterpret_cast<uintptr_t>(t.data_ptr()) % 8 == 0 for FP16/BF16.
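The stricter predicate can be sketched as a pure function over the raw address (a sketch mirroring the checks quoted above; the helper name and example addresses are hypothetical):

```python
def can_vectorize(addr: int, last_dim: int, itemsize: int) -> bool:
    """A 4-wide vectorized load (float4 / ushort4) moves 4 * itemsize bytes,
    so the base pointer must be aligned to that full transaction size --
    D % 4 == 0 alone is not sufficient."""
    required = 4 * itemsize  # 16 bytes for FP32, 8 bytes for FP16/BF16
    return last_dim % 4 == 0 and addr % required == 0

# FP32 (itemsize=4): base must be 16-byte aligned.
print(can_vectorize(0x7F0000000000, 1536, 4))  # True
print(can_vectorize(0x7F0000000004, 1536, 4))  # False: load would be misaligned
# FP16/BF16 (itemsize=2): 8-byte alignment suffices for ushort4.
print(can_vectorize(0x7F0000000008, 1536, 2))  # True
print(can_vectorize(0x7F0000000000, 1535, 2))  # False: D % 4 != 0
```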
Related to #12799 (Roadmap)
Related to #14437 (apply the fused kernel to diffusion models; this PR only implements the kernel in the sgl-kernel repo)

Motivation
This PR fuses the residual-add, normalization, and scale-shift steps into a single kernel to eliminate redundant memory accesses and reduce GPU memory traffic in diffusion models.
Modifications
sgl-kernel

Accuracy Tests
Use the native Python implementation as the baseline:
(These tolerances may be further refined based on broader model-level evaluation.)
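A minimal numpy sketch of that unfused Python baseline (the function name, tensor shapes, and broadcast pattern are assumptions of this sketch, not the test's exact code):

```python
import numpy as np

def scale_residual_norm_scale_shift(x, residual, gate, scale, shift, eps=1e-6):
    """Unfused reference: residual-add, LayerNorm over the last dim
    (no learned affine assumed), then scale-shift modulation."""
    residual_out = residual + gate * x
    mean = residual_out.mean(axis=-1, keepdims=True)
    var = residual_out.var(axis=-1, keepdims=True)
    normalized = (residual_out - mean) / np.sqrt(var + eps)
    modulated = (1.0 + scale) * normalized + shift
    return residual_out, modulated

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 8, 16)).astype(np.float32)
res = rng.standard_normal((2, 8, 16)).astype(np.float32)
gate = rng.standard_normal((2, 1, 16)).astype(np.float32)
scale = rng.standard_normal((2, 1, 16)).astype(np.float32)
shift = rng.standard_normal((2, 1, 16)).astype(np.float32)
r_out, mod = scale_residual_norm_scale_shift(x, res, gate, scale, shift)
print(r_out.shape, mod.shape)
```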
Benchmarking and Profiling
Single Kernel Test: 4.99x speedup
Benchmarking on H200 across multiple (seq_len, hidden_dim) settings shows that the fused kernel achieves an average 4.99× speedup over the unfused baseline (details here), with smaller shapes yielding larger gains.

E2E Test
Performance Comparison Report: Wan-AI/Wan2.1-T2V-1.3B-Diffusers (1.07x speedup)
1. High-level Summary
2. Stage Breakdown
Metadata
e6ced88040d0d35622e3ee46166cd2d41bd57091

Performance Comparison Report: hunyuanvideo-community/HunyuanVideo
No clear speedup is seen on Hunyuan, likely because its much longer sequence length (115,200 in Hunyuan vs. 32,760 in Wan) makes the attention kernel the dominant cost.
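A back-of-envelope check of that explanation, assuming attention cost grows quadratically with sequence length while the fused elementwise/norm ops grow only linearly:

```python
# Going from Wan's 32,760 tokens to Hunyuan's 115,200 tokens scales the
# linear-cost ops by ~3.5x but the quadratic attention cost by ~12x, so
# attention's share of total runtime grows and dilutes the fusion gain.
hunyuan_seq, wan_seq = 115_200, 32_760
ratio = hunyuan_seq / wan_seq
print(round(ratio, 2), round(ratio**2, 1))  # 3.52 12.4
```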
1. High-level Summary
2. Stage Breakdown
Metadata
d276cc35c6e5c5ca9453c7a6e369a383cf206263

Profiling: 78% SOL bandwidth
Using a real case from the wan model (batch=1, seq=32760, hidden_dim=1536), the fused kernel processes approximately:
The kernel completes in 213 µs, yielding an effective bandwidth of:
800 MB / 213 µs ≈ 3.76 TB/s. This reaches ~78% of the theoretical H200 HBM bandwidth (4.8 TB/s), consistent with the ncu report.

Checklist
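The arithmetic behind those figures, with the ~800 MB of traffic reconstructed under the assumption of BF16 tensors and roughly eight tensor-sized reads plus writes (the exact tensor count is an assumption of this sketch):

```python
# Traffic: ~8 tensor-sized transfers of (1, 32760, 1536) BF16 elements.
traffic_bytes = 8 * 1 * 32760 * 1536 * 2   # ~805 MB, matching the ~800 MB figure
# Effective bandwidth and speed-of-light fraction on H200 (4.8 TB/s peak).
bandwidth = 800e6 / 213e-6                 # bytes per second
sol = bandwidth / 4.8e12
print(round(traffic_bytes / 1e6), round(bandwidth / 1e12, 2), round(sol, 2))
# -> 805 3.76 0.78
```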
Others
Kernel Impl
residual_out = residual + gate * x;
normalized = norm(residual_out);
modulated = (1 + scale) * normalized + shift.
Use the Welford algorithm for numerical stability:
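A scalar Python sketch of the Welford update (the CUDA kernel additionally merges per-thread partial statistics across the block, which is omitted here):

```python
def welford(xs):
    """One-pass running mean/variance. Numerically stabler than accumulating
    sum and sum-of-squares, since it avoids the catastrophic cancellation
    in E[x^2] - E[x]^2 -- which is why the kernel uses it for the norm stats."""
    mean, m2, n = 0.0, 0.0, 0
    for v in xs:
        n += 1
        delta = v - mean
        mean += delta / n
        m2 += delta * (v - mean)
    return mean, m2 / n  # population variance, as LayerNorm uses

print(welford([1.0, 2.0, 3.0, 4.0]))  # (2.5, 1.25)
```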