diff --git a/python/sglang/multimodal_gen/runtime/layers/layernorm.py b/python/sglang/multimodal_gen/runtime/layers/layernorm.py index ec8f680a7876..f3170d4147b0 100644 --- a/python/sglang/multimodal_gen/runtime/layers/layernorm.py +++ b/python/sglang/multimodal_gen/runtime/layers/layernorm.py @@ -10,11 +10,7 @@ import torch.nn.functional as F from sglang.multimodal_gen.runtime.layers.custom_op import CustomOp -from sglang.multimodal_gen.runtime.layers.triton_ops import ( - fuse_scale_shift_kernel, - norm_infer, - rms_norm_fn, -) +from sglang.multimodal_gen.runtime.layers.triton_ops import norm_infer, rms_norm_fn from sglang.multimodal_gen.runtime.utils.common import ( get_bool_env_var, is_cpu, @@ -30,7 +26,7 @@ _is_cpu = is_cpu() _is_xpu = is_xpu() -from sgl_kernel import fused_add_rmsnorm, rmsnorm +from sgl_kernel import fused_add_rmsnorm, rmsnorm, scale_residual_norm_scale_shit # Copied and adapted from sglang @@ -269,100 +265,129 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor: ).to(origin_dtype) -class ScaleResidualLayerNormScaleShift(nn.Module): - """ - Fused operation that combines: - 1. Gated residual connection - 2. LayerNorm - 3. Scale and shift operations - - This reduces memory bandwidth by combining memory-bound operations. 
- """ +@CustomOp.register("scale_residual_norm_scale_shift") +class ScaleResidualNormScaleShift(CustomOp): def __init__( self, hidden_size: int, + eps=1e-6, norm_type: str = "rms", - eps: float = 1e-6, elementwise_affine: bool = False, - dtype: torch.dtype = torch.float32, - compute_dtype: torch.dtype | None = None, - prefix: str = "", + bias: bool = False, + device=None, + dtype=None, ): super().__init__() - if norm_type == "rms": - self.norm = RMSNorm( - hidden_size, has_weight=elementwise_affine, eps=eps, dtype=dtype + self.hidden_size = hidden_size + self.eps = eps + self.norm_type = norm_type.lower() + self.norm = nn.Module() + + factory_kwargs = {"device": device, "dtype": dtype} + if elementwise_affine: + self.norm.weight = torch.nn.Parameter( + torch.empty(hidden_size, **factory_kwargs) ) - elif norm_type == "layer": - if compute_dtype == torch.float32: - self.norm = FP32LayerNorm( - hidden_size, elementwise_affine=elementwise_affine, eps=eps + if self.norm_type == "layer" and bias: + self.norm.bias = torch.nn.Parameter( + torch.empty(hidden_size, **factory_kwargs) ) else: - self.norm = LayerNorm( - hidden_size, - elementwise_affine=elementwise_affine, - eps=eps, - dtype=dtype, - ) + self.norm.register_parameter("bias", None) else: - raise NotImplementedError(f"Norm type {norm_type} not implemented") + self.norm.register_parameter("weight", None) + self.norm.register_parameter("bias", None) - def forward( + def forward_cuda( self, residual: torch.Tensor, x: torch.Tensor, gate: torch.Tensor | int, shift: torch.Tensor, scale: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Apply gated residual connection, followed by layernorm and - scale/shift in a single fused operation. 
- - Returns: - Tuple containing: - - normalized and modulated output of shape: [batch_size, seq_len, inner_dim] - - residual value (value after residual connection - but before normalization) - """ - # x.shape: [batch_size, seq_len, inner_dim] - # Apply residual connection with gating - if isinstance(gate, int): - # used by cross-attention, should be 1 - assert gate == 1 - residual_output = residual + x - elif isinstance(gate, torch.Tensor): + ): + scale_residual_norm_scale_shit( + residual, + x, + gate if isinstance(gate, torch.Tensor) else None, + self.norm.weight, + self.norm.bias, + scale, + shift, + self.eps, + self.norm_type, + ) + + def forward_native( + self, + residual: torch.Tensor, + x: torch.Tensor, + gate: torch.Tensor | int, + shift: torch.Tensor, + scale: torch.Tensor, + ): + # 1. residual add + if isinstance(gate, torch.Tensor): if gate.dim() == 4: # gate.shape: [batch_size, num_frames, 1, inner_dim] num_frames = gate.shape[1] frame_seqlen = x.shape[1] // num_frames - residual_output = residual + ( + residual_out = residual + ( x.unflatten(dim=1, sizes=(num_frames, frame_seqlen)) * gate ).flatten(1, 2) else: - # used by bidirectional self attention # gate.shape: [batch_size, 1, inner_dim] - residual_output = residual + x * gate + residual_out = residual + x * gate else: - raise ValueError(f"Gate type {type(gate)} not supported") - # residual_output.shape: [batch_size, seq_len, inner_dim] - - # Apply normalization - normalized = self.norm(residual_output) - - # modulated = fused_scale_shift( - # normalized, - # scale, - # shift, - # ) - modulated = fuse_scale_shift_kernel( - normalized, - scale, - shift, - ) - return modulated, residual_output + residual_out = residual + x * gate + # 2. 
normalize + if self.norm_type == "layer": # LayerNorm + mean = residual_out.mean(dim=-1, keepdim=True) + var = residual_out.var(dim=-1, unbiased=False, keepdim=True) + normalized = (residual_out - mean) / torch.sqrt(var + self.eps) + elif self.norm_type == "rms": # RMSNorm + rms = residual_out.pow(2).mean(dim=-1, keepdim=True) + normalized = residual_out / torch.sqrt(rms + self.eps) + # 3. apply affine transform if given + norm_weight, norm_bias = self.norm.weight, self.norm.bias + if norm_weight is not None and norm_bias is not None: + normalized = normalized * norm_weight + norm_bias + elif norm_weight is not None: + normalized = normalized * norm_weight + # 4. apply scale/shift if given + batch, seq_len, hidden_dim = x.shape + if scale.ndim <= 3: + if scale.ndim == 0 or (scale.ndim == 1 and scale.numel() == 1): + # (), (1) → (B, S, D) + scale = scale.expand(batch, seq_len, hidden_dim) + shift = shift.expand(batch, seq_len, hidden_dim) + elif scale.ndim == 2 and scale.shape in [ + (1, hidden_dim), + (batch, hidden_dim), + ]: + # (B, D) or (1, D) → (B, S, 1, D) + scale = scale[:, None, :].expand(batch, seq_len, hidden_dim) + shift = shift[:, None, :].expand(batch, seq_len, hidden_dim) + elif scale.ndim == 3 and scale.shape in [ + (batch, seq_len, hidden_dim), + (batch, 1, hidden_dim), + (1, seq_len, hidden_dim), + (1, 1, hidden_dim), + ]: + # (B, S, D), (B, 1, D), (1, S, D), (1, 1, D) → (B, S, 1, D) + scale = scale.expand(batch, seq_len, hidden_dim) + shift = shift.expand(batch, seq_len, hidden_dim) + normalized = normalized * (1.0 + scale) + shift + elif scale.ndim == 4 and scale.shape == (batch, scale.shape[1], 1, hidden_dim): + num_frames = scale.shape[1] + frame_seqlen = normalized.shape[1] // num_frames + normalized = ( + normalized.unflatten(dim=1, sizes=(num_frames, frame_seqlen)) + * (1.0 + scale) + + shift + ).flatten(1, 2) + return normalized, residual_out class LayerNormScaleShift(nn.Module): diff --git 
a/python/sglang/multimodal_gen/runtime/models/dits/causal_wanvideo.py b/python/sglang/multimodal_gen/runtime/models/dits/causal_wanvideo.py index 2789ebdf385d..03e13edda842 100644 --- a/python/sglang/multimodal_gen/runtime/models/dits/causal_wanvideo.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/causal_wanvideo.py @@ -29,7 +29,7 @@ LayerNormScaleShift, RMSNorm, ScaleResidual, - ScaleResidualLayerNormScaleShift, + ScaleResidualNormScaleShift, ) from sglang.multimodal_gen.runtime.layers.linear import ReplicatedLinear from sglang.multimodal_gen.runtime.layers.mlp import MLP @@ -292,25 +292,24 @@ def __init__( print("QK Norm type not supported") raise Exception assert cross_attn_norm is True - self.self_attn_residual_norm = ScaleResidualLayerNormScaleShift( + self.self_attn_residual_norm = ScaleResidualNormScaleShift( dim, norm_type="layer", eps=eps, elementwise_affine=True, + bias=True, dtype=torch.float32, - compute_dtype=torch.float32, ) # 2. Cross-attention # Only T2V for now self.attn2 = WanT2VCrossAttention(dim, num_heads, qk_norm=qk_norm, eps=eps) - self.cross_attn_residual_norm = ScaleResidualLayerNormScaleShift( + self.cross_attn_residual_norm = ScaleResidualNormScaleShift( dim, norm_type="layer", eps=eps, elementwise_affine=False, dtype=torch.float32, - compute_dtype=torch.float32, ) # 3. 
Feed-forward diff --git a/python/sglang/multimodal_gen/runtime/models/dits/hunyuanvideo.py b/python/sglang/multimodal_gen/runtime/models/dits/hunyuanvideo.py index ad9e10dd5290..391583c87469 100644 --- a/python/sglang/multimodal_gen/runtime/models/dits/hunyuanvideo.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/hunyuanvideo.py @@ -19,7 +19,7 @@ LayerNormScaleShift, RMSNorm, ScaleResidual, - ScaleResidualLayerNormScaleShift, + ScaleResidualNormScaleShift, ) from sglang.multimodal_gen.runtime.layers.linear import ReplicatedLinear from sglang.multimodal_gen.runtime.layers.mlp import MLP @@ -74,7 +74,7 @@ def __init__( self.img_attn_norm = LayerNormScaleShift( hidden_size, norm_type="layer", elementwise_affine=False, dtype=dtype ) - self.img_attn_residual_mlp_norm = ScaleResidualLayerNormScaleShift( + self.img_attn_residual_mlp_norm = ScaleResidualNormScaleShift( hidden_size, norm_type="layer", elementwise_affine=False, dtype=dtype ) self.img_mlp_residual = ScaleResidual() @@ -120,7 +120,7 @@ def __init__( self.txt_attn_norm = LayerNormScaleShift( hidden_size, norm_type="layer", elementwise_affine=False, dtype=dtype ) - self.txt_attn_residual_mlp_norm = ScaleResidualLayerNormScaleShift( + self.txt_attn_residual_mlp_norm = ScaleResidualNormScaleShift( hidden_size, norm_type="layer", elementwise_affine=False, dtype=dtype ) self.txt_mlp_residual = ScaleResidual() diff --git a/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py b/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py index cb674e49195b..5a219e1c3be0 100644 --- a/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py @@ -21,7 +21,7 @@ LayerNormScaleShift, RMSNorm, ScaleResidual, - ScaleResidualLayerNormScaleShift, + ScaleResidualNormScaleShift, ) from sglang.multimodal_gen.runtime.layers.linear import ReplicatedLinear from sglang.multimodal_gen.runtime.layers.mlp import MLP @@ -294,13 +294,13 @@ def __init__( 
logger.error("QK Norm type not supported") raise Exception assert cross_attn_norm is True - self.self_attn_residual_norm = ScaleResidualLayerNormScaleShift( + self.self_attn_residual_norm = ScaleResidualNormScaleShift( dim, norm_type="layer", eps=eps, elementwise_affine=True, + bias=True, dtype=torch.float32, - compute_dtype=torch.float32, ) # 2. Cross-attention @@ -322,13 +322,13 @@ def __init__( eps=eps, supported_attention_backends=supported_attention_backends, ) - self.cross_attn_residual_norm = ScaleResidualLayerNormScaleShift( + self.cross_attn_residual_norm = ScaleResidualNormScaleShift( dim, norm_type="layer", eps=eps, elementwise_affine=False, + bias=False, dtype=torch.float32, - compute_dtype=torch.float32, ) # 3. Feed-forward @@ -469,13 +469,13 @@ def __init__( logger.error("QK Norm type not supported") raise Exception assert cross_attn_norm is True - self.self_attn_residual_norm = ScaleResidualLayerNormScaleShift( + self.self_attn_residual_norm = ScaleResidualNormScaleShift( dim, norm_type="layer", eps=eps, elementwise_affine=True, + bias=True, dtype=torch.float32, - compute_dtype=torch.float32, ) if AttentionBackendEnum.VIDEO_SPARSE_ATTN in supported_attention_backends: @@ -499,13 +499,13 @@ def __init__( eps=eps, supported_attention_backends=supported_attention_backends, ) - self.cross_attn_residual_norm = ScaleResidualLayerNormScaleShift( + self.cross_attn_residual_norm = ScaleResidualNormScaleShift( dim, norm_type="layer", eps=eps, elementwise_affine=False, + bias=False, dtype=torch.float32, - compute_dtype=torch.float32, ) # 3. 
import random
import unittest
from typing import Tuple

import torch
from torch import Tensor

from sglang.multimodal_gen.runtime.layers.layernorm import ScaleResidualNormScaleShift
from sglang.test.test_utils import CustomTestCase


def is_float(value):
    """Return True if *value* parses as a Python float."""
    try:
        float(value)
        return True
    except ValueError:
        return False


def allclose_with_tolerance(
    x: Tensor, y: Tensor, atol: float, rtol: float, max_ratio=0.005
) -> Tuple[bool, str]:
    """Elementwise closeness check that tolerates a small fraction of outliers.

    Returns ``(ok, message)`` where ``ok`` is True when at most *max_ratio*
    of the elements violate ``|x - y| <= atol + rtol * |y|``.
    """
    diff = torch.abs(x - y)
    th = atol + rtol * torch.abs(y)

    # out-of-tolerance mask
    bad_mask = diff > th
    bad_ratio = bad_mask.float().mean().item()

    return bad_ratio <= max_ratio, f"{bad_ratio:.6f} > {max_ratio}"


################################################################################
# Accuracy Test
################################################################################


class TestScaleResidualNormScaleShiftAccuracy(CustomTestCase):
    """Compares the fused CUDA path against forward_native on random configs."""

    DTYPES = [torch.float32, torch.float16, torch.bfloat16]
    PARAM_DTYPES = [torch.float32, torch.float16, torch.bfloat16]
    BATCH_NUM = [1, 2]
    SEQ = [83, 1024, 2047, 32760]
    FRAME = [4, 8, 16]
    HIDDEN_SIZES = [83, 1024, 1338, 1536, 3072, 4096]
    USE_AFFINE = [False, True]
    USE_BIAS = [False, True]
    NORM_TYPE = ["rms", "layer"]
    GATE_SHAPE = ["1", "1D", "B1D", "BF1D"]
    SCALE_SHIFT_SHAPE = ["1.2", "[1]", "BD", "1D", "BSD", "B1D", "1SD", "11D", "BF1D"]
    SEEDS = [0]

    args = [
        (DTYPES, "dtype"),
        (PARAM_DTYPES, "param_type"),
        (BATCH_NUM, "batch"),
        (SEQ, "seq"),
        (FRAME, "frame"),
        (HIDDEN_SIZES, "hidden_size"),
        (USE_AFFINE, "use_affine"),
        (USE_BIAS, "use_bias"),
        (NORM_TYPE, "norm_type"),
        (GATE_SHAPE, "gate_shape"),
        (SCALE_SHIFT_SHAPE, "scale_shift_shape"),
        (SEEDS, "seed"),
    ]

    def gen_cases(self, num):
        """Yield *num* randomly sampled parameter combinations.

        Uses a dedicated fixed-seed RNG so the sampled cases are identical
        across runs (bug fix: sampling from the unseeded module-global
        ``random`` made failing cases irreproducible).
        """
        rng = random.Random(0)
        num_of_args = len(self.args)
        for _ in range(num):
            yield {
                self.args[i][1]: rng.choice(self.args[i][0])
                for i in range(num_of_args)
            }

    @classmethod
    def setUpClass(cls):
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA is not available")
        torch.set_default_device("cuda")

    def _run_fused_test(
        self,
        **params,
    ):
        dtype = params["dtype"]
        param_type = params["param_type"]
        batch = params["batch"]
        seq = params["seq"]
        frame = params["frame"]
        hidden_size = params["hidden_size"]
        use_affine = params["use_affine"]
        use_bias = params["use_bias"]
        norm_type = params["norm_type"]
        gate_shape = params["gate_shape"]
        scale_shift_shape = params["scale_shift_shape"]
        seed = params["seed"]

        torch.manual_seed(seed)
        layer = ScaleResidualNormScaleShift(
            hidden_size,
            elementwise_affine=use_affine,
            bias=use_bias,
            dtype=param_type,
            norm_type=norm_type,
        )
        # Perturb the affine parameters away from their default init so the
        # affine code path is actually exercised.
        if use_affine:
            w = torch.empty_like(layer.norm.weight)
            w.normal_(mean=1.0, std=0.1)
            layer.norm.weight.data.copy_(w)
            if norm_type == "layer" and use_bias:
                b = torch.empty_like(layer.norm.bias)
                b.normal_(mean=1.0, std=0.1)
                layer.norm.bias.data.copy_(b)

        residual = torch.randn(batch, seq, hidden_size, dtype=dtype)
        x = torch.randn(batch, seq, hidden_size, dtype=dtype)

        if gate_shape == "1":
            gate = 1
        elif gate_shape == "1D":
            gate = torch.randn(1, hidden_size, dtype=dtype)
        elif gate_shape == "BF1D":
            # Per-frame gating needs the sequence to split evenly into frames.
            if seq % frame != 0:
                return
            gate = torch.randn(batch, frame, 1, hidden_size, dtype=dtype)
        elif gate_shape == "B1D":
            gate = torch.randn(batch, 1, hidden_size, dtype=dtype)
        else:
            raise ValueError("Unknown gate shape.")

        if is_float(scale_shift_shape):
            scale = torch.tensor(1.0, dtype=dtype) * float(scale_shift_shape)
            shift = torch.tensor(1.0, dtype=dtype) * float(scale_shift_shape)
        elif scale_shift_shape == "[1]":
            scale = torch.ones(1, dtype=dtype) * torch.rand(1).item()
            shift = torch.ones(1, dtype=dtype) * torch.rand(1).item()
        elif scale_shift_shape == "BD":
            scale = torch.randn(batch, hidden_size, dtype=dtype)
            shift = torch.randn(batch, hidden_size, dtype=dtype)
        elif scale_shift_shape == "1D":
            scale = torch.randn(1, hidden_size, dtype=dtype)
            shift = torch.randn(1, hidden_size, dtype=dtype)
        elif scale_shift_shape == "BSD":
            scale = torch.randn(batch, seq, hidden_size, dtype=dtype)
            shift = torch.randn(batch, seq, hidden_size, dtype=dtype)
        elif scale_shift_shape == "B1D":
            scale = torch.randn(batch, 1, hidden_size, dtype=dtype)
            shift = torch.randn(batch, 1, hidden_size, dtype=dtype)
        elif scale_shift_shape == "1SD":
            scale = torch.randn(1, seq, hidden_size, dtype=dtype)
            shift = torch.randn(1, seq, hidden_size, dtype=dtype)
        elif scale_shift_shape == "11D":
            scale = torch.randn(1, 1, hidden_size, dtype=dtype)
            shift = torch.randn(1, 1, hidden_size, dtype=dtype)
        elif scale_shift_shape == "BF1D":
            if seq % frame != 0:
                return
            scale = torch.randn(batch, frame, 1, hidden_size, dtype=dtype)
            shift = torch.randn(batch, frame, 1, hidden_size, dtype=dtype)
        else:
            # Bug fix: an unknown shape string previously fell through and
            # crashed with NameError on `scale` instead of a clear message.
            raise ValueError("Unknown scale/shift shape.")

        with torch.inference_mode():
            ref_out_mod, ref_out_resi = layer.forward_native(
                residual, x, gate, shift, scale
            )
            out_mod, out_resi = layer(residual, x, gate, shift, scale)

        # fp32 end-to-end gets tight tolerances; mixed precision is looser.
        if dtype == torch.float32 and param_type == torch.float32:
            self.assertTrue(
                *allclose_with_tolerance(out_mod, ref_out_mod, atol=1e-6, rtol=1e-4)
            )
            self.assertTrue(
                *allclose_with_tolerance(out_resi, ref_out_resi, atol=1e-6, rtol=1e-4)
            )
        else:
            self.assertTrue(
                *allclose_with_tolerance(out_mod, ref_out_mod, atol=5e-2, rtol=1e-2)
            )
            self.assertTrue(
                *allclose_with_tolerance(out_resi, ref_out_resi, atol=5e-2, rtol=1e-2)
            )

    def test_fused(self):
        for params in self.gen_cases(num=300):
            with self.subTest(**params):
                self._run_fused_test(**params)
                torch.cuda.synchronize()


if __name__ == "__main__":
    unittest.main(verbosity=2)