Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
114 commits
Select commit Hold shift + click to select a range
1ee2726
init cutedsl kernel
AichenF Nov 26, 2025
6edc5a1
cutedsl scale shift kernel
AichenF Nov 27, 2025
3483b2f
fused layernorm scale shift kernel v1
Nov 27, 2025
c68a1e6
support bf16
jianyingzhu Nov 28, 2025
6069dc5
bf16 accuracy beat triton with speedup
jianyingzhu Nov 28, 2025
87a5fad
update test fused kernel
jianyingzhu Dec 2, 2025
34e1e89
fused kernel into sgl-kernel, build success
jianyingzhu Dec 2, 2025
2008766
add residual gate fused kernel
jianyingzhu Dec 8, 2025
51de512
update calling func
jianyingzhu Dec 8, 2025
4f1149b
major update
jianyingzhu Dec 8, 2025
ca9a1b6
build sgl_kernel on a100 and pass pytest
jianyingzhu Dec 8, 2025
0be916b
build success, run wan success with speedup
jianyingzhu Dec 9, 2025
7475998
move test.py
AichenF Dec 9, 2025
bdf9018
update test.py
AichenF Dec 9, 2025
de8ccfd
code clean
AichenF Dec 9, 2025
e45937e
tune block size for different m for better performance
AichenF Dec 9, 2025
fc382e6
fused kernel
jianyingzhu Dec 9, 2025
413b4b9
update layernorm.py
jianyingzhu Dec 10, 2025
d6a6db3
kernel upd
AichenF Dec 10, 2025
fbbf772
support hunyuan video
jianyingzhu Dec 10, 2025
fac69d4
simplify code
jianyingzhu Dec 10, 2025
35f8d82
remove log
jianyingzhu Dec 10, 2025
313d815
precommit to format code
Dec 10, 2025
41cf85d
tolerance to 5e-2
jianyingzhu Dec 10, 2025
82d3676
add more cases in test
jianyingzhu Dec 10, 2025
60d49ef
add more cases in test
jianyingzhu Dec 10, 2025
eb42f4a
restore kernel
AichenF Dec 11, 2025
0673d66
add no affine fused kernel for qwen-image
AichenF Dec 11, 2025
25868e3
upd
AichenF Dec 11, 2025
704aba2
update default and benchmark
jianyingzhu Dec 11, 2025
1819dd1
upd test
AichenF Dec 11, 2025
943338a
clean kernel launch
AichenF Dec 11, 2025
97b0dff
support 1,c / 1,1,c broadcast shape
jianyingzhu Dec 11, 2025
e0eedc9
pre-commit formatting code
jianyingzhu Dec 11, 2025
742ff41
upd benchmark
AichenF Dec 11, 2025
a869fd0
formatting code
jianyingzhu Dec 11, 2025
6269651
restore benchmark script
jianyingzhu Dec 12, 2025
f849d7f
fix _cuda and sgl_kernel import
jianyingzhu Dec 12, 2025
83f8933
code formatting
jianyingzhu Dec 12, 2025
13d0425
1) merge w/o affine kernels; 2) fix eps inside kernel; 3) add optiona…
Dec 14, 2025
b870267
tests: add tests
Dec 14, 2025
872cdc0
support optional gate/weight/bias in scale_residual_layernorm_scale_s…
Dec 14, 2025
0ea11e3
sglang bind
Dec 14, 2025
e9857a5
use Optional[torch.Tensor] instead
Dec 14, 2025
90e8576
fix param order and code formatting
jianyingzhu Dec 15, 2025
7eb8b23
1) rms norm and tests; 2) clean dispatch code with zero-overhead;
Dec 16, 2025
99271d2
removing the incorrectly transmitted file.
Dec 16, 2025
c095573
modify CMakeLists
Dec 16, 2025
4392215
update layernorm.py
Dec 16, 2025
c23e96b
update benchmark
Dec 16, 2025
c3505fd
kernel rename: layernorm -> norm (both layer and rms)
Dec 16, 2025
2f1f6e4
check: add norm_type check
Dec 16, 2025
0c7d586
Merge branch 'main' into feat/layernorm_scale_shift_kernel
jianyingzhu Dec 19, 2025
5915743
fix linting
jianyingzhu Dec 19, 2025
d6b2df8
Merge branch 'main' into feat/layernorm_scale_shift_kernel
yingluosanqian Dec 29, 2025
9f0fd3b
transfer to jit_kernel
yingluosanqian Dec 29, 2025
fa994dd
ditto
yingluosanqian Dec 29, 2025
dd60ce7
update layernorm
yingluosanqian Dec 29, 2025
e1e4c31
Merge branch 'main' into feat/layernorm_scale_shift_kernel
yingluosanqian Dec 30, 2025
739e55c
update layernorm; resolve conflict in qwen_image
yingluosanqian Dec 30, 2025
45c1833
Merge branch 'main' into feat/layernorm_scale_shift_kernel
yingluosanqian Dec 30, 2025
24677e8
fused further
yingluosanqian Jan 2, 2026
dea1a55
Merge remote-tracking branch 'upstream/main' into layernorm_scale_shi…
yingluosanqian Jan 2, 2026
bd032de
remove compute_dtype because always fp32 in fused kernel
jianyingzhu Jan 6, 2026
931ffa1
remove compute_dtype in wan because always fp32 in fused kernel
jianyingzhu Jan 6, 2026
8806b8b
update hunyuan LayerNormScaleShift
jianyingzhu Jan 6, 2026
77ae2f4
Merge aichenf/main into feat/layernorm_scale_shift_kernel
jianyingzhu Jan 6, 2026
2200b69
rename get_args
jianyingzhu Jan 6, 2026
9b22ebb
clean code
yingluosanqian Jan 6, 2026
1b3d928
update benchmark
yingluosanqian Jan 6, 2026
03ef897
update benchmark using @triton.testing.perf_report
jianyingzhu Jan 7, 2026
0537424
replace T4 with aligned_vector, update type checking
jianyingzhu Jan 7, 2026
8f5516a
simplifying test code, gate_mode to template args
jianyingzhu Jan 7, 2026
328b678
Merge branch 'main' into feat/layernorm_scale_shift_kernel
jianyingzhu Jan 7, 2026
fe7f822
revert text_encoding.py
jianyingzhu Jan 7, 2026
7b7883b
update type checking of qwen-image modulate
jianyingzhu Jan 7, 2026
13cc7d4
combine 2d and 4d kernel using index mapping, simplify code
jianyingzhu Jan 11, 2026
ac07be9
Merge upstream/main into feat/layernorm_scale_shift_kernel
jianyingzhu Jan 12, 2026
877e9f5
Merge branch 'main' into feat/layernorm_scale_shift_kernel
jianyingzhu Jan 14, 2026
ed7cf6d
fix lint
jianyingzhu Jan 14, 2026
e2b5458
Merge remote-tracking branch 'origin/main' into feat/layernorm_scale_…
yingluosanqian Jan 14, 2026
caa5d31
.
yingluosanqian Jan 14, 2026
139b4e3
.
yingluosanqian Jan 14, 2026
ed08703
update block reduce sum
jianyingzhu Jan 15, 2026
3325baf
update norm template
jianyingzhu Jan 15, 2026
307e8c3
update norm template
jianyingzhu Jan 15, 2026
223da94
update norm template
jianyingzhu Jan 16, 2026
41f7c47
Merge branch 'main' into feat/layernorm_scale_shift_kernel
jianyingzhu Jan 16, 2026
c64935a
clean and bugfix
yingluosanqian Jan 20, 2026
d568a53
clean
yingluosanqian Jan 20, 2026
aab5b2c
clean
yingluosanqian Jan 20, 2026
fb1dfcf
.
yingluosanqian Jan 20, 2026
c3cef4a
.
yingluosanqian Jan 20, 2026
0dd2b6e
clean
yingluosanqian Jan 20, 2026
a817cd4
finish cutedsl kernel
yingluosanqian Jan 22, 2026
0425164
Merge remote-tracking branch 'upstream/main' into cutedsl_norm
yingluosanqian Jan 22, 2026
5377418
tensor validation
yingluosanqian Jan 22, 2026
beff96c
validate dtype
yingluosanqian Jan 22, 2026
76f927b
tvm ffi
yingluosanqian Jan 25, 2026
7e7c01f
clean
yingluosanqian Jan 25, 2026
3bb67d7
format
yingluosanqian Jan 25, 2026
e560ec7
use fp32 as acc dtype to improve precision
yingluosanqian Jan 25, 2026
767f55e
support gate as None
yingluosanqian Jan 27, 2026
55312f8
move kernel to diffusion/ dir
yingluosanqian Jan 27, 2026
c7fef08
Merge branch 'main' into feat/layernorm_scale_shift_kernel
yingluosanqian Jan 28, 2026
6001c74
Merge branch 'main' into feat/layernorm_scale_shift_kernel
BBuf Jan 29, 2026
b766f4c
add fallback, clean code
yingluosanqian Jan 29, 2026
e78f4d4
fallback fix
yingluosanqian Jan 29, 2026
e9c24ff
workaround, enable video warmup
yingluosanqian Jan 31, 2026
60c6241
Merge branch 'main' into feat/layernorm_scale_shift_kernel
BBuf Feb 1, 2026
abc2961
Merge remote-tracking branch 'upstream/main' into feat/layernorm_scal…
yingluosanqian Feb 2, 2026
ae7f3e2
[amd]bugfix for qwen
yingluosanqian Feb 2, 2026
588252d
[amd] bugfix
yingluosanqian Feb 3, 2026
80f342f
Merge branch 'main' into feat/layernorm_scale_shift_kernel
yingluosanqian Feb 3, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 134 additions & 0 deletions python/sglang/jit_kernel/benchmark/bench_fused_norm_scale_shift.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
# Benchmarks SGLang fused layernorm/rmsnorm scale shift kernels
# 1. fused_norm_scale_shift
# 2. fused_scale_residual_norm_scale_shift
import itertools
from typing import Tuple

import torch
import triton
import triton.testing

from sglang.jit_kernel.benchmark.utils import is_in_ci
from sglang.multimodal_gen.runtime.layers.layernorm import (
LayerNormScaleShift,
RMSNormScaleShift,
ScaleResidualLayerNormScaleShift,
ScaleResidualRMSNormScaleShift,
)

if is_in_ci():
    # Keep CI runs fast: a single small shape.
    B_RANGE, S_RANGE, D_RANGE = [1], [128], [1024]
else:
    B_RANGE, S_RANGE, D_RANGE = [1], [128, 1024, 4096], [1024, 3072, 4096]

# Sweep both norm flavors, each with and without learnable affine params.
NORM_TYPE_RANGE = ["layer", "rms"]
AFFINE_RANGE = [True, False]
DTYPE = torch.bfloat16
DEVICE = "cuda"
EPS = 1e-5
# Plot lines: the unfused native path vs. the fused CUDA kernel.
LINE_VALS = ["native", "cuda"]
LINE_NAMES = ["SGLang Native", "SGLang Fused"]
STYLES = [("red", "-"), ("blue", "--")]
# Full cartesian product of every benchmarked configuration.
config = list(
    itertools.product(B_RANGE, S_RANGE, D_RANGE, NORM_TYPE_RANGE, AFFINE_RANGE)
)


def preprocess_layer(layer, affine: bool, D: int, DTYPE: torch.dtype):
    """Randomize the norm's affine params (if any), freeze, and move to DEVICE.

    Args:
        layer: Module exposing a `.norm` sub-module with `weight` (and
            optionally `bias`) parameters.
        affine: Whether the norm carries learnable affine parameters.
        D: Length of the normalized dimension.
        DTYPE: dtype for the freshly-randomized parameters.

    Returns:
        The same layer, gradient-free, on DEVICE.
    """
    if affine:
        with torch.no_grad():
            layer.norm.weight.copy_(torch.randn(D, dtype=DTYPE, device=DEVICE))
            # Some norms (e.g. RMSNorm) have no bias parameter.
            if hasattr(layer.norm, "bias"):
                layer.norm.bias.copy_(torch.randn(D, dtype=DTYPE, device=DEVICE))
    layer.requires_grad_(False)
    return layer.to(DEVICE)


# ============================================================================
# Benchmark 1: fused_norm_scale_shift
# ============================================================================
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["B", "S", "D", "norm_type", "affine"],
        x_vals=config,
        line_arg="provider",
        line_vals=LINE_VALS,
        line_names=LINE_NAMES,
        styles=STYLES,
        ylabel="us",
        plot_name="fused_norm_scale_shift",
        args={},
    )
)
def bench_fused_norm_scale_shift(
    B: int, S: int, D: int, norm_type, affine: bool, provider: str
) -> Tuple[float, float, float]:
    """Time one (B, S, D, norm_type, affine) config for a single provider.

    Returns (median, max, min) latency in microseconds, as expected by
    triton.testing.perf_report.
    """
    shape = (B, S, D)
    x = torch.randn(shape, dtype=DTYPE, device=DEVICE)
    scale = torch.randn(shape, dtype=DTYPE, device=DEVICE)
    shift = torch.randn(shape, dtype=DTYPE, device=DEVICE)

    layer_cls = LayerNormScaleShift if norm_type == "layer" else RMSNormScaleShift
    layer = preprocess_layer(layer_cls(D, EPS, affine, dtype=DTYPE), affine, D, DTYPE)

    # Bind the forward once so the timed closure does no branching.
    fwd = layer.forward_native if provider == "native" else layer.forward_cuda

    def run():
        return fwd(x, shift, scale)

    ms, min_ms, max_ms = triton.testing.do_bench(run, quantiles=[0.5, 0.2, 0.8])
    return 1000 * ms, 1000 * max_ms, 1000 * min_ms  # convert to us


# ============================================================================
# Benchmark 2: fused_scale_residual_norm_scale_shift
# ============================================================================
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["B", "S", "D", "norm_type", "affine"],
        x_vals=config,
        line_arg="provider",
        line_vals=LINE_VALS,
        line_names=LINE_NAMES,
        styles=STYLES,
        ylabel="us",
        plot_name="fused_scale_residual_norm_scale_shift",
        args={},
    )
)
def bench_fused_scale_residual_norm_scale_shift(
    B: int, S: int, D: int, norm_type, affine: bool, provider: str
) -> Tuple[float, float, float]:
    """Time the scale-residual + norm + scale/shift fusion for one config.

    Compares the native (unfused) forward against the fused CUDA kernel.

    Returns:
        (median, max, min) latency in microseconds, as expected by
        triton.testing.perf_report.
    """
    residual = torch.randn(B, S, D, dtype=DTYPE, device=DEVICE)
    x = torch.randn(B, S, D, dtype=DTYPE, device=DEVICE)
    scale = torch.randn(B, S, D, dtype=DTYPE, device=DEVICE)
    shift = torch.randn(B, S, D, dtype=DTYPE, device=DEVICE)
    # Gate broadcasts over the sequence dimension (shape [B, 1, D]).
    gate = torch.randn(B, 1, D, dtype=DTYPE, device=DEVICE)
    # NOTE: construct on the default device; preprocess_layer already moves
    # the layer to DEVICE, so the previous `.to(DEVICE)` here was redundant
    # (and inconsistent with benchmark 1).
    if norm_type == "layer":
        layer = ScaleResidualLayerNormScaleShift(D, EPS, affine, dtype=DTYPE)
    else:
        layer = ScaleResidualRMSNormScaleShift(D, EPS, affine, dtype=DTYPE)
    layer = preprocess_layer(layer, affine, D, DTYPE)
    if provider == "native":
        fn = lambda: layer.forward_native(residual, x, gate, shift, scale)
    else:
        fn = lambda: layer.forward_cuda(residual, x, gate, shift, scale)

    quantiles = [0.5, 0.2, 0.8]
    ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)
    return 1000 * ms, 1000 * max_ms, 1000 * min_ms  # convert to us


if __name__ == "__main__":
    # Run both benchmark sweeps, each under a banner.
    banner = "=" * 80
    for title, bench in (
        ("fused_norm_scale_shift", bench_fused_norm_scale_shift),
        (
            "fused_scale_residual_norm_scale_shift",
            bench_fused_scale_residual_norm_scale_shift,
        ),
    ):
        print(f"\n{banner}")
        print(f"Benchmark: {title}")
        print(f"{banner}\n")
        bench.run(print_data=True)
201 changes: 201 additions & 0 deletions python/sglang/jit_kernel/diffusion/cutedsl/common/norm_fusion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
from typing import Optional, Tuple, Union

import cutlass
import cutlass.cute as cute
import torch
from einops import rearrange

from sglang.jit_kernel.diffusion.cutedsl.common.reduce import (
cta_reduce_sum,
warp_reduce_sum,
)


@cute.jit
def apply_norm_cta(
    norm_type: cutlass.Constexpr,
    num_warps: cutlass.Constexpr,
    tidx: cutlass.Int32,
    tXrX: cute.Tensor,
    tWrW: Optional[cute.Tensor],
    tBrB: Optional[cute.Tensor],
    D: Union[cutlass.Int32, cutlass.Constexpr],
    eps: Union[cutlass.Float32, cutlass.Constexpr],
) -> cute.Tensor:
    """Dispatch to RMSNorm or LayerNorm at compile time.

    `norm_type` is a compile-time value ("rms" selects RMSNorm, anything
    else selects LayerNorm), so `const_expr` resolves the branch during
    tracing and only one normalization path is emitted into the kernel.

    Args:
        norm_type: Compile-time norm selector.
        num_warps: Number of warps cooperating in the CTA-wide reduction.
        tidx: Thread index within the CTA.
        tXrX: Per-thread register fragment holding this row's input values.
        tWrW: Optional weight fragment (affine scale), or None.
        tBrB: Optional bias fragment (used by LayerNorm only), or None.
        D: Length of the normalized dimension.
        eps: Numerical-stability epsilon added before the rsqrt.

    Returns:
        A new register fragment containing the normalized values.
    """
    if cutlass.const_expr(norm_type == "rms"):
        return apply_rmsnorm_cta(num_warps, tidx, tXrX, tWrW, D, eps)
    else:
        return apply_layernorm_cta(num_warps, tidx, tXrX, tWrW, tBrB, D, eps)


@cute.jit
def apply_rmsnorm_cta(
    num_warps: Union[cutlass.Int32, cutlass.Constexpr],
    tidx: cutlass.Int32,
    tXrX: cute.Tensor,
    tWrW: Optional[cute.Tensor],
    D: Union[cutlass.Int32, cutlass.Constexpr],
    eps: Union[cutlass.Float32, cutlass.Constexpr],
) -> cute.Tensor:
    """
    RMSNorm:
        y[i] = x[i] / sqrt(sum(x ^ 2) / D + eps) * w[i]

    The squared sum is reduced across the whole CTA (warp shuffle followed
    by a shared-memory cross-warp reduction), so every thread sees the same
    normalization factor.  The weight is applied only when a weight
    fragment is provided (compile-time branch).
    """
    val = cute.Float32(0.0)
    for idx in range(cute.size(tXrX)):
        # Accumulate in FP32 to improve numerical precision.
        x_fp32 = tXrX[idx].to(cutlass.Float32)
        val += x_fp32 * x_fp32
    # Per-warp partial sum, then CTA-wide sum broadcast via smem.
    val = warp_reduce_sum(val)
    acc_sq = cta_reduce_sum(val, num_warps, tidx)
    factor = cute.rsqrt(acc_sq / D + eps)
    # Results are cast back to the input fragment's element type on store.
    tNrN = cute.make_fragment_like(tXrX)
    if cutlass.const_expr(isinstance(tWrW, cute.Tensor)):
        tNrN.store((tXrX.load() * factor * tWrW.load()).to(tNrN.element_type))
    else:
        tNrN.store((tXrX.load() * factor).to(tNrN.element_type))
    return tNrN


@cute.jit
def apply_layernorm_cta(
    num_warps: Union[cutlass.Int32, cutlass.Constexpr],
    tidx: cutlass.Int32,
    tXrX: cute.Tensor,
    tWrW: Optional[cute.Tensor],
    tBrB: Optional[cute.Tensor],
    D: Union[cutlass.Int32, cutlass.Constexpr],
    eps: Union[cutlass.Float32, cutlass.Constexpr],
) -> cute.Tensor:
    """
    LayerNorm:
        mean = sum(x) / D
        var = sum((x - mean) ^ 2) / D
        y[i] = (x[i] - mean) / sqrt(var + eps) * w[i] + b[i]

    Performs two CTA-wide reduction passes (mean, then variance), both
    accumulated in FP32.  Weight and bias are applied only when BOTH
    fragments are provided (compile-time branch); otherwise the plain
    normalized value is returned.
    """
    # Reduce mean
    val = cute.Float32(0.0)
    for idx in range(cute.size(tXrX)):
        # Accumulate in FP32 to improve numerical precision.
        val += tXrX[idx].to(cutlass.Float32)
    val = warp_reduce_sum(val)
    val = cta_reduce_sum(val, num_warps, tidx)
    mean = val / D
    # Reduce variance
    val = cute.Float32(0.0)
    for idx in range(cute.size(tXrX)):
        # Accumulate in FP32 to improve numerical precision.
        x_fp32 = tXrX[idx].to(cutlass.Float32)
        val += (x_fp32 - mean) * (x_fp32 - mean)
    val = warp_reduce_sum(val)
    val = cta_reduce_sum(val, num_warps, tidx)
    factor = cute.rsqrt(val / D + eps)
    # Normalize
    tNrN = cute.make_fragment_like(tXrX)
    if cutlass.const_expr(
        isinstance(tWrW, cute.Tensor) and isinstance(tBrB, cute.Tensor)
    ):
        tNrN.store(
            ((tXrX.load() - mean) * factor * tWrW.load() + tBrB.load()).to(
                tNrN.element_type
            )
        )
    else:
        tNrN.store(((tXrX.load() - mean) * factor).to(tNrN.element_type))
    return tNrN


################################################################################
# BSFD Indexing
################################################################################
# In diffusion norm-fusion kernels, we compute `norm(x) + y`, where
# `x` has shape [B, S, D] and `y` may come in various broadcastable forms:
# [1], [D], [1, D], [1, 1, D], [B, D], [B, 1, D], [B, S, D], or [B, F, 1, D].
#
# For a given (batch_id, seq_id), the index mapping for `y` falls into 3 cases:
# 1) Scalar broadcast [1]:
# (batch_id, seq_id, *) -> (0)
# 2) Frame-based BSFD broadcast [B, F, 1, D]:
# frame_id = seq_id // len_frame
# (batch_id, seq_id, *) -> (batch_id, frame_id, *)
# 3) All other cases:
# `y` is broadcast to [B, S, D] (via view/expand, no materialization),
# and indexed as (batch_id, seq_id, *).
#
# This helper normalizes `y` into a BSFD-compatible view so that kernel
# indexing logic remains simple and uniform.
################################################################################


def broadcast_tensor_for_bsfd(
    tensor: Union[Optional[torch.Tensor], int],
    B: int,
    S: int,
    D: int,
) -> Union[Optional[torch.Tensor], int]:
    """
    Broadcast `tensor` to (B, S, D) without a memory copy.

    Uses native torch view ops (no einops needed for these trivial
    reshapes).  Supported input shapes:
        [D], [1, D], [1, 1, D], [B, D], [B, 1, D], [B, S, D].

    Special cases:
        - Non-tensor values (e.g. None or an int) pass through unchanged.
        - A 1-element 1-D tensor ([1]) is preserved as-is and handled
          specially in the CuTe kernel.
        - 4-D tensors ([B, F, 1, D]) pass through; frame-based indexing is
          resolved inside the kernel (see `tensor_slice_for_bsfd`).

    Raises:
        ValueError: for tensors with ndim > 4.
    """

    # Return directly for non-tensor value
    if not isinstance(tensor, torch.Tensor):
        return tensor

    if tensor.ndim == 1:
        # Scalar [1] is preserved as-is and handled specially in CuTe kernel.
        if tensor.numel() == 1:
            return tensor
        # [D] -> (1, 1, D) view, then zero-copy expand.
        return tensor.reshape(1, 1, -1).expand(B, S, D)
    if tensor.ndim == 2:
        # [B, D] / [1, D] -> (B|1, 1, D) view, then zero-copy expand.
        return tensor.unsqueeze(1).expand(B, S, D)
    if tensor.ndim == 3:
        return tensor.expand(B, S, D)
    if tensor.ndim == 4:
        return tensor
    raise ValueError(f"BSFD broadcast: unsupported tensor ndim: {tensor.ndim}.")


@cute.jit
def tensor_slice_for_bsfd(
    mV: cute.Tensor,
    thr_copy: cute.ThrCopy,
    batch_id: cutlass.Int32,
    seq_id: cutlass.Int32,
    S: Union[cutlass.Int32, cutlass.Constexpr],
    D: Union[cutlass.Int32, cutlass.Constexpr],
) -> Tuple[cute.Tensor, cute.Tensor]:
    """
    Slice a BSFD-compatible tensor into a per-thread gmem tile and rmem fragment.

    Given a logical (batch_id, seq_id), this helper selects the corresponding
    D-length slice from `mV` and prepares it for vectorized copy.

    Args:
        mV: Global-memory tensor in one of the BSFD-compatible forms
            (1-D scalar, 3-D [B, S, D] view, or 4-D [B, F, 1, D]).
        thr_copy: Thread copy used to partition the gmem slice.
        batch_id: Logical batch coordinate.
        seq_id: Logical sequence coordinate.
        S: Sequence length (used to derive the frame length for 4-D inputs).
        D: Length of the innermost (normalized) dimension.

    Returns:
        (tVgV, tVrV): the partitioned gmem source tile and a matching rmem
        destination fragment.
    """
    gV: cute.Tensor
    # Compile-time scalar case: a statically-known single-element tensor.
    if cutlass.const_expr(cute.is_static(mV.layout) and cute.size(mV.layout) == 1):
        # build a ((1,1),(1,)) layout so it could broadcast-align with the
        # regular rmem fragment shape ((4,1),(k,)).
        layout = cute.make_layout(shape=((1, 1), (1,)))
        tVgV = cute.make_tensor(mV.iterator, layout)
        tVrV = cute.make_rmem_tensor(layout, mV.element_type)
        return tVgV, tVrV

    # Use `local_tile` instead of direct indexing to preserve gmem base pointer
    # alignment required for vectorized loads.
    if cutlass.const_expr(len(mV.shape) == 1):
        # 1-D: the tensor itself is already the D-length slice.
        gV = mV
    elif cutlass.const_expr(len(mV.shape) == 3):
        # 3-D [B, S, D]: pick the (batch_id, seq_id) row.
        gV = cute.local_tile(mV, tiler=(1, 1, D), coord=(batch_id, seq_id, 0))
        gV = gV[0, 0, None]
    elif cutlass.const_expr(len(mV.shape) == 4):
        # Compute frame length at runtime (instead of compile time) to avoid
        # specializing kernels on the frame dimension.
        frame_len = S // mV.shape[1]
        frame_id = seq_id // frame_len
        gV = cute.local_tile(mV, tiler=(1, 1, 1, D), coord=(batch_id, frame_id, 0, 0))
        gV = gV[0, 0, 0, None]
    else:
        raise NotImplementedError(f"BSFD slice: unsupported shape {mV.shape}.")
    # Partition the gmem slice per-thread and allocate a matching rmem fragment.
    tVgV = thr_copy.partition_S(gV)
    tVrV = cute.make_fragment_like(tVgV, tVgV.element_type)
    return tVgV, tVrV
33 changes: 33 additions & 0 deletions python/sglang/jit_kernel/diffusion/cutedsl/common/reduce.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import math

import cutlass
import cutlass.cute as cute


@cute.jit
def warp_reduce_sum(val: cute.Numeric, reduce_size: int = 32) -> cute.Numeric:
    """Sum `val` across `reduce_size` lanes of a warp via shuffle-down.

    Runs log2(reduce_size) shuffle-down steps with halving offsets
    (reduce_size/2, ..., 2, 1); lane 0 ends up holding the full sum (other
    lanes hold partial sums).  `reduce_size` must be a power of two.
    """
    iters = int(math.log2(reduce_size))
    for i in range(iters):
        val = val + cute.arch.shuffle_sync_down(val, offset=1 << (iters - i - 1))
    return val


@cute.jit
def cta_reduce_sum(
    val: cute.Numeric, num_warps: cutlass.Constexpr, tidx: cutlass.Int32
) -> cute.Numeric:
    """Reduce per-warp partial sums into a single CTA-wide sum.

    Expects each warp's `val` to already be warp-reduced (lane 0 holds the
    warp's partial, as produced by `warp_reduce_sum`).  Lane 0 of every warp
    stages its partial into shared memory; warp 0 reduces the staged values
    and writes the total back to smem, which all threads then read, so every
    thread returns the same sum.

    NOTE(review): assumes num_warps <= 32 so warp 0 can cover all staged
    partials with a single warp_reduce_sum — confirm callers respect this.
    """
    smem = cutlass.utils.SmemAllocator()
    acc = smem.allocate_tensor(cutlass.Float32, num_warps)
    warp_id = tidx >> 5
    lane_id = tidx & 31
    if lane_id == 0:
        acc[warp_id] = val
    cute.arch.sync_threads()
    if warp_id == 0:
        # Lanes beyond num_warps contribute the additive identity (0).
        val = acc[lane_id] if lane_id < num_warps else cutlass.Float32(0)
        val = warp_reduce_sum(val)
        if lane_id == 0:
            acc[0] = val
    cute.arch.sync_threads()
    val = acc[0]
    return val
Loading
Loading