[diffusion] kernel fusion: scale residual norm scale shift and add gate norm kernel fusion for Z-Image #19249

linfann wants to merge 4 commits into sgl-project:main
Conversation
Summary of Changes

Hello @linfann, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed.

This pull request introduces significant performance optimizations for Z-Image diffusion models by implementing two new CUDA fused kernels. These kernels combine several common operations, such as residual connections, gating, and normalization (LayerNorm and RMSNorm), into single, more efficient GPU operations. The integration aims to reduce kernel launch overhead and has shown promising end-to-end latency improvements, particularly in the denoising phase. While overall performance has improved, a regression in the decoding stage's latency has been identified and is slated for future work.
Activity
Code Review
This pull request introduces two new fused CUDA kernels, fused_norm_residual_gate_add_norm_scale and fused_add_gate_norm, to optimize the Z-Image model by reducing kernel launch overhead. The changes look promising and show a good performance improvement. I've identified a few issues, including a bug in the tensor validation logic, incorrect docstrings for the new kernels, and a bug in the fallback logic for the CUDA implementation. Addressing these will improve the correctness and maintainability of the code.
```python
elif t.ndim == 4 and (t.shape[0] != B or t.shape[2] != 1 or t.shape[3] != D):
    F = t.shape[1]
    if S % F != 0:
        raise ValueError(f"Validate failed: S({S}) must be divisible by F({F}).")
    failed = True
```
The validation logic for 4D tensors is flawed. The check if S % F != 0: is only performed if the elif condition on line 234 is met. However, for a valid shape like (B, F, 1, D), this condition is false, so the divisibility of S by F is never checked. This could lead to runtime errors if an invalid tensor is passed. This same issue exists in the duplicated validate_scale_shift function in python/sglang/jit_kernel/diffusion/cutedsl/norm_residual_gate_add_norm_scale.py.
Suggested change:

```diff
-elif t.ndim == 4 and (t.shape[0] != B or t.shape[2] != 1 or t.shape[3] != D):
-    F = t.shape[1]
-    if S % F != 0:
-        raise ValueError(f"Validate failed: S({S}) must be divisible by F({F}).")
-    failed = True
+elif t.ndim == 4:
+    F = t.shape[1]
+    if S % F != 0:
+        raise ValueError(f"Validate failed: S({S}) must be divisible by F({F}).")
+    if t.shape[0] != B or t.shape[2] != 1 or t.shape[3] != D:
+        failed = True
```
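To see why the original `elif` silently skips the check, here is a minimal standalone sketch (hypothetical helper name, plain shape tuples instead of tensors) of the corrected branch: the `S % F` divisibility check runs for every 4D input, and the shape comparison only decides whether validation failed.

```python
def validate_4d(shape, B, S, D):
    # Hypothetical standalone sketch of the corrected 4D branch.
    # The S % F divisibility check must run for every 4D tensor,
    # not only when the shape already mismatches (B, F, 1, D).
    if len(shape) == 4:
        F = shape[1]
        if S % F != 0:
            raise ValueError(f"Validate failed: S({S}) must be divisible by F({F}).")
        if shape[0] != B or shape[2] != 1 or shape[3] != D:
            return False  # mirrors `failed = True` in the original
    return True

# With the original code, a tensor shaped exactly (B, F, 1, D) makes the
# elif condition false, so S % F is never checked; here, (2, 3, 1, 8)
# with S=5 raises because 5 is not divisible by F=3.
```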
```python
    gate: torch.Tensor | int,
    scale: torch.Tensor,
) -> torch.Tensor:
    if x.shape[-1] % 256 != 0 and x.shape[-1] <= 8192:
```
The condition to fall back to the native implementation is incorrect. It uses and where it should use or, and the upper bound check is x.shape[-1] <= 8192 instead of x.shape[-1] > 8192. This could cause the fused kernel to be called with unsupported dimensions (e.g., D > 8192), leading to a ValueError from the kernel instead of a graceful fallback. A similar bug exists in _AddGateNorm.forward_cuda on line 636.
Suggested change:

```diff
-if x.shape[-1] % 256 != 0 and x.shape[-1] <= 8192:
+if x.shape[-1] % 256 != 0 or x.shape[-1] > 8192:
```
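The corrected predicate can be checked in isolation. This hypothetical helper encodes the fallback rule (the fused path requires the last dimension to be a multiple of 256 and at most 8192):

```python
def should_fallback(d: int) -> bool:
    # Fall back to the native implementation unless the fused kernel's
    # constraints hold: d is a multiple of 256 AND d <= 8192.
    return d % 256 != 0 or d > 8192

# 8448 = 33 * 256 is a multiple of 256 but exceeds 8192; the buggy
# `and`/`<=` version evaluates to False there and would wrongly send
# the input to the fused kernel, which then raises a ValueError.
```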
```python
"""
Fuse: norm(x) * (1 + scale) + shift
where norm is either layernorm or rmsnorm.

Expects:
- x: [B, S, D]
- weight/bias: None, [D]
- scale/shift: [1], [D], [1/B, D], [1/B, 1/S, D] or [B, F, 1, D]
- norm_type: str, "layer" or "rms"
- eps: Optional[float], default: 1e-5
```
The docstring for fused_add_gate_norm appears to be incorrect. It describes a norm(x) * (1 + scale) + shift operation and expects scale/shift parameters, which doesn't match the function's implementation or signature. The actual operation is x + gate * norm(residual). Please update the docstring to accurately reflect the function's behavior and parameters.
Suggested change:

```diff
 """
-Fuse: norm(x) * (1 + scale) + shift
+Fuse: x + gate * norm(residual)
 where norm is either layernorm or rmsnorm.
 Expects:
 - x: [B, S, D]
-- weight/bias: None, [D]
-- scale/shift: [1], [D], [1/B, D], [1/B, 1/S, D] or [B, F, 1, D]
+- residual: [B, S, D]
+- gate: None, or broadcastable to [B, S, D]
+- weight/bias: None, [D]
 - norm_type: str, "layer" or "rms"
 - eps: Optional[float], default: 1e-5
 """
```
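For reference, the `x + gate * norm(residual)` semantics can be expressed as a small NumPy sketch (a reference implementation for testing only, not the actual kernel; RMSNorm shown, helper name hypothetical):

```python
import numpy as np

def add_gate_rmsnorm_ref(x, residual, gate, weight=None, eps=1e-5):
    # Reference semantics of the fused op: x + gate * rmsnorm(residual).
    rms = np.sqrt(np.mean(residual**2, axis=-1, keepdims=True) + eps)
    out = residual / rms
    if weight is not None:
        out = out * weight  # optional elementwise weight of shape [D]
    return x + gate * out
```

A sketch like this is handy as the ground truth in the accuracy tests comparing against the fused CUDA path.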
```python
"""
Fuse: norm(x) * (1 + scale) + shift
where norm is either layernorm or rmsnorm.

Expects:
- x: [B, S, D]
- weight/bias: None, [D]
- scale/shift: [1], [D], [1/B, D], [1/B, 1/S, D] or [B, F, 1, D]
- norm_type: str, "layer" or "rms"
- eps: Optional[float], default: 1e-5
```
The docstring for fused_norm_residual_gate_add_norm_scale seems incorrect. It describes the operation as norm(x) * (1 + scale) + shift, which doesn't match the implementation. The function actually computes norm_out = norm2(x + gate * norm1(residual)) * (scale + 1) and also returns the intermediate residual_out. The Expects section also incorrectly mentions shift. Please update the docstring for clarity.
Suggested docstring:

```python
"""
Fuse:
    residual_out = x + gate * norm1(residual)
    norm_out = norm2(residual_out) * (scale + 1.0)
where norm is either layernorm or rmsnorm.
Returns:
    A tuple of (norm_out, residual_out).
Expects:
- x: [B, S, D]
- residual: [B, S, D]
- gate: None, or broadcastable to [B, S, D]
- weight1/bias1: None, [D] for norm1
- weight2/bias2: None, [D] for norm2
- scale: broadcastable to [B, S, D]
- norm_type: str, "layer" or "rms"
- eps: Optional[float], default: 1e-5
"""
```

```python
        stacklevel=2,
    )
    return self.forward_native(residual, x, gate, scale)
    # TODO: use fused kernel
```
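The two-output semantics in the suggested docstring can likewise be sketched in NumPy (reference only; RMSNorm for both norms, helper names hypothetical):

```python
import numpy as np

def rmsnorm(t, eps=1e-5):
    # Plain RMSNorm along the last dimension, no learned weight.
    return t / np.sqrt(np.mean(t**2, axis=-1, keepdims=True) + eps)

def norm_residual_gate_add_norm_scale_ref(x, residual, gate, scale, eps=1e-5):
    # residual_out = x + gate * norm1(residual)
    # norm_out     = norm2(residual_out) * (scale + 1.0)
    residual_out = x + gate * rmsnorm(residual, eps)
    norm_out = rmsnorm(residual_out, eps) * (scale + 1.0)
    return norm_out, residual_out
```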
```python
    x: torch.Tensor,
    gate: torch.Tensor | int,
) -> torch.Tensor:
    logger.info("### use cuda fused_add_gate_norm")
```
This logger.info call appears to be for debugging. To avoid verbose logs in production, please consider removing it or changing it to logger.debug. A similar log message exists on line 676.
Suggested change:

```diff
-logger.info("### use cuda fused_add_gate_norm")
+logger.debug("### use cuda fused_add_gate_norm")
```
Force-pushed from efbe6d1 to 3ee3e75 (Compare)
```python
def __init__(self, D: int, norm_type: str):
    self.D = D
    self.norm_type = norm_type  # "layer" or "rms"
    self.num_warps = self.D // 256  # num of warps per cta
```
This means that each warp processes 256 elements, which means each CUDA thread processes 8 elements, is that correct?
@linfann I think there are two points worth discussing here:
- Shouldn't this be a ceil_div?
- Blackwell supports 256-bit ld/st, and a warp can process 512 elements.
This means that each warp processes 256 elements, which means each CUDA thread processes 8 elements, is that correct?
This kernel assumes that the reduction dimension (N) is a multiple of 256. Fixing the number of warps and forcing each warp to process exactly 256 elements is not the optimal strategy for performance. Instead, we only need to assume that N is a multiple of 8 and use predicates to avoid unnecessary load/store/compute operations like this.
Since these kernels share the same norm template, this change will also affect another kernel. @linfann You can try implementing it like this. If it is too complicated, we can leave it for me to handle in the next PR.
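A `ceil_div`, as suggested above, keeps the warp count correct when D is not an exact multiple of the per-warp tile (the numbers below just spell out the arithmetic from the discussion):

```python
def ceil_div(a: int, b: int) -> int:
    # Round-up integer division, the usual kernel-launch sizing idiom.
    return (a + b - 1) // b

# 256 elements per warp / 32 threads per warp = 8 elements per thread.
# D = 3072 gives 12 warps either way, but D = 3000 shows the difference:
# floor division yields 11 warps (covering only 2816 elements), while
# ceil_div yields 12, with predicates masking the tail.
```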
```python
@cute.jit
def copy_if(src, dst):
    if cutlass.const_expr(
        isinstance(src, cute.Tensor) and isinstance(src, cute.Tensor)
```
Yes, this typo also appears in the referenced PR.
In ZImage there are two execution paths: one with …

OK, I will implement the latter one.
```python
_COMPILE_CACHE = {}


def to_cute_arg(
```
can we make this function a common utility function?


Motivation
Optimize Z-Image via kernel fusion (refer to #14717).

This PR adds the `fused_norm_residual_gate_add_norm_scale` kernel. The kernel fusion reduces kernel launch overhead for Z-Image.

Modifications

- In `sgl_kernel`, add the `fused_norm_residual_gate_add_norm_scale` and `fused_add_gate_norm` CUDA kernels based on CUTLASS. Support both LayerNorm and RMSNorm.
Accuracy Tests
Benchmarking and Profiling
Benchmark
bench_fused_norm_residual_gate_add_norm_scale
bench_fused_add_gate_norm
Profiling
no compile
Command:
1. High-level Summary
2. Stage Breakdown
compile
Command:
1. High-level Summary
2. Stage Breakdown
Confusion:

`--enable-torch-compile` costs more time.

Checklist