From 7c767d7c94695d95b6255772d8f24358357d66bf Mon Sep 17 00:00:00 2001 From: Yihan Chen Date: Thu, 4 Dec 2025 16:09:04 +0000 Subject: [PATCH 1/8] merge conflict --- .../runtime/layers/layernorm.py | 116 ++++++ .../runtime/models/dits/causal_wanvideo.py | 9 +- .../runtime/models/dits/hunyuanvideo.py | 6 +- .../test_scale_residual_norm_scale_shift.py | 374 ++++++++++++++++++ 4 files changed, 497 insertions(+), 8 deletions(-) create mode 100644 python/sglang/multimodal_gen/test/test_scale_residual_norm_scale_shift.py diff --git a/python/sglang/multimodal_gen/runtime/layers/layernorm.py b/python/sglang/multimodal_gen/runtime/layers/layernorm.py index ec8f680a7876..855231f0164b 100644 --- a/python/sglang/multimodal_gen/runtime/layers/layernorm.py +++ b/python/sglang/multimodal_gen/runtime/layers/layernorm.py @@ -269,6 +269,122 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor: ).to(origin_dtype) +@CustomOp.register("scale_residual_norm_scale_shift") +class ScaleResidualNormScaleShift(CustomOp): + + def __init__( + self, + hidden_size: int, + eps=1e-6, + norm_type: str = "rms", + elementwise_affine: bool = False, + bias: bool = False, + device=None, + dtype=None, + ): + super().__init__() + self.hidden_size = hidden_size + self.eps = eps + self.norm_type = norm_type.lower() + self.norm = nn.Module() + + factory_kwargs = {"device": device, "dtype": dtype} + if elementwise_affine: + self.norm.weight = torch.nn.Parameter( + torch.empty(hidden_size, **factory_kwargs) + ) + if self.norm_type == "layer" and bias: + self.norm.bias = torch.nn.Parameter( + torch.empty(hidden_size, **factory_kwargs) + ) + else: + self.norm.register_parameter("bias", None) + else: + self.norm.register_parameter("weight", None) + self.norm.register_parameter("bias", None) + + def forward_cuda( + self, + residual: torch.Tensor, + x: torch.Tensor, + gate: torch.Tensor | int, + shift: torch.Tensor, + scale: torch.Tensor, + ): + return torch.ops.sglang_ops.scale_residual_norm_scale_shift( + residual, x, + gate if isinstance(gate, torch.Tensor) else None, + self.norm.weight, self.norm.bias, + scale, shift, + self.eps, self.norm_type == "rms", + ) + + def forward_native( + self, + residual: torch.Tensor, + x: torch.Tensor, + gate: torch.Tensor | int, + shift: torch.Tensor, + scale: torch.Tensor, + ): + # 1. residual add + if isinstance(gate, torch.Tensor): + if gate.dim() == 4: + # gate.shape: [batch_size, num_frames, 1, inner_dim] + num_frames = gate.shape[1] + frame_seqlen = x.shape[1] // num_frames + residual_out = residual + ( + x.unflatten(dim=1, sizes=(num_frames, frame_seqlen)) * gate + ).flatten(1, 2) + else: + # gate.shape: [batch_size, 1, inner_dim] + residual_out = residual + x * gate + else: + residual_out = residual + x * gate + # 2. normalize + if self.norm_type == "layer": # LayerNorm + mean = residual_out.mean(dim=-1, keepdim=True) + var = residual_out.var(dim=-1, unbiased=False, keepdim=True) + normalized = (residual_out - mean) / torch.sqrt(var + self.eps) + elif self.norm_type == "rms": # RMSNorm + rms = residual_out.pow(2).mean(dim=-1, keepdim=True) + normalized = residual_out / torch.sqrt(rms + self.eps) + # 3. apply affine transform if given + norm_weight, norm_bias = self.norm.weight, self.norm.bias + if norm_weight is not None and norm_bias is not None: + normalized = normalized * norm_weight + norm_bias + elif norm_weight is not None: + normalized = normalized * norm_weight + # 4. apply scale/shift if given + batch, seq_len, hidden_dim = x.shape + if scale.ndim <= 3: + if scale.ndim == 0 or (scale.ndim == 1 and scale.size == 1): + # (), (1) → (B, S, D) + scale = scale.expand(batch, seq_len, hidden_dim) + shift = shift.expand(batch, seq_len, hidden_dim) + elif scale.ndim == 2 and scale.shape in [(1, hidden_dim), (batch, hidden_dim)]: + # (B, D) or (1, D) → (B, S, 1, D) + scale = scale[:, None, :].expand(batch, seq_len, hidden_dim) + shift = shift[:, None, :].expand(batch, seq_len, hidden_dim) + elif scale.ndim == 3 and scale.shape in [ + (batch, seq_len, hidden_dim), + (batch, 1, hidden_dim), + (1, seq_len, hidden_dim), + (1, 1, hidden_dim), + ]: + # (B, S, D), (B, 1, D), (1, S, D), (1, 1, D) → (B, S, 1, D) + scale = scale.expand(batch, seq_len, hidden_dim) + shift = shift.expand(batch, seq_len, hidden_dim) + normalized = normalized * (1.0 + scale) + shift + elif scale.ndim == 4 and scale.shape == (batch, scale.shape[1], 1, hidden_dim): + num_frames = scale.shape[1] + frame_seqlen = normalized.shape[1] // num_frames + normalized = ( + normalized.unflatten(dim=1, sizes=(num_frames, frame_seqlen)) + * (1.0 + scale) + shift).flatten(1, 2) + return normalized, residual_out + + class ScaleResidualLayerNormScaleShift(nn.Module): """ Fused operation that combines: diff --git a/python/sglang/multimodal_gen/runtime/models/dits/causal_wanvideo.py b/python/sglang/multimodal_gen/runtime/models/dits/causal_wanvideo.py index 2789ebdf385d..03e13edda842 100644 --- a/python/sglang/multimodal_gen/runtime/models/dits/causal_wanvideo.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/causal_wanvideo.py @@ -29,7 +29,7 @@ LayerNormScaleShift, RMSNorm, ScaleResidual, - ScaleResidualLayerNormScaleShift, + ScaleResidualNormScaleShift, ) from sglang.multimodal_gen.runtime.layers.linear import ReplicatedLinear from sglang.multimodal_gen.runtime.layers.mlp import MLP @@ -292,25 +292,24 @@ def __init__( print("QK Norm type not supported") raise Exception assert cross_attn_norm is True - self.self_attn_residual_norm = ScaleResidualLayerNormScaleShift( + self.self_attn_residual_norm = ScaleResidualNormScaleShift( dim, norm_type="layer", eps=eps, elementwise_affine=True, + bias=True, dtype=torch.float32, - compute_dtype=torch.float32, ) # 2. Cross-attention # Only T2V for now self.attn2 = WanT2VCrossAttention(dim, num_heads, qk_norm=qk_norm, eps=eps) - self.cross_attn_residual_norm = ScaleResidualLayerNormScaleShift( + self.cross_attn_residual_norm = ScaleResidualNormScaleShift( dim, norm_type="layer", eps=eps, elementwise_affine=False, dtype=torch.float32, - compute_dtype=torch.float32, ) # 3. Feed-forward diff --git a/python/sglang/multimodal_gen/runtime/models/dits/hunyuanvideo.py b/python/sglang/multimodal_gen/runtime/models/dits/hunyuanvideo.py index ad9e10dd5290..391583c87469 100644 --- a/python/sglang/multimodal_gen/runtime/models/dits/hunyuanvideo.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/hunyuanvideo.py @@ -19,7 +19,7 @@ LayerNormScaleShift, RMSNorm, ScaleResidual, - ScaleResidualLayerNormScaleShift, + ScaleResidualNormScaleShift, ) from sglang.multimodal_gen.runtime.layers.linear import ReplicatedLinear from sglang.multimodal_gen.runtime.layers.mlp import MLP @@ -74,7 +74,7 @@ def __init__( self.img_attn_norm = LayerNormScaleShift( hidden_size, norm_type="layer", elementwise_affine=False, dtype=dtype ) - self.img_attn_residual_mlp_norm = ScaleResidualLayerNormScaleShift( + self.img_attn_residual_mlp_norm = ScaleResidualNormScaleShift( hidden_size, norm_type="layer", elementwise_affine=False, dtype=dtype ) self.img_mlp_residual = ScaleResidual() @@ -120,7 +120,7 @@ def __init__( self.txt_attn_norm = LayerNormScaleShift( hidden_size, norm_type="layer", elementwise_affine=False, dtype=dtype ) - self.txt_attn_residual_mlp_norm = ScaleResidualLayerNormScaleShift( + self.txt_attn_residual_mlp_norm = ScaleResidualNormScaleShift( hidden_size, norm_type="layer", elementwise_affine=False, dtype=dtype ) self.txt_mlp_residual = ScaleResidual() diff --git a/python/sglang/multimodal_gen/test/test_scale_residual_norm_scale_shift.py b/python/sglang/multimodal_gen/test/test_scale_residual_norm_scale_shift.py new file mode 100644 index 000000000000..d4fe98b7465a --- /dev/null +++ b/python/sglang/multimodal_gen/test/test_scale_residual_norm_scale_shift.py @@ -0,0 +1,374 @@ + +from sglang.multimodal_gen.runtime.layers.triton_ops import fuse_scale_shift_kernel +from sglang.multimodal_gen.test.utils import allclose_with_tolerance +from sglang.test.test_utils import CustomTestCase +from sglang.multimodal_gen.runtime.layers.layernorm import ScaleResidualNormScaleShift +from sglang.multimodal_gen.runtime.layers.layernorm import ScaleResidualLayerNormScaleShift, LayerNorm +import torch.nn as nn +import random +import unittest +import itertools + +import torch +import time +import os +print("CUDA_VISIBLE_DEVICES =", os.environ.get("CUDA_VISIBLE_DEVICES")) +print("torch sees", torch.cuda.device_count(), "GPUs") +print("Using device:", torch.cuda.get_device_name(0)) + + +################################################################################ +# Benchmark +################################################################################ +def is_float(value): + try: + float(value) + return True + except ValueError: + return False + + +def benchmark(fn, *args, **kwargs): + warmup = 2 + iters = 20 + # Make sure everything is clean + with torch.inference_mode(): + if torch.cuda.is_available(): + torch.cuda.synchronize() + + # --- Warmup --- + for _ in range(warmup): + fn(*args, **kwargs) + if torch.cuda.is_available(): + torch.cuda.synchronize() + + # --- Timed runs --- + t0 = time.time() + for _ in range(iters): + fn(*args, **kwargs) + if torch.cuda.is_available(): + torch.cuda.synchronize() + t1 = time.time() + + # Average time per iteration (ms) + avg_time_ms = (t1 - t0) * 1000 / iters + return avg_time_ms + + +################################################################################ +# Accuracy Test +################################################################################ +class TestScaleResidualNormScaleShiftAccuracy(CustomTestCase): + DTYPES = [torch.float32, torch.float16, torch.bfloat16] + PARAM_DTYPES = [torch.float32, torch.float16, torch.bfloat16] + BATCH_NUM = [1, 2] + SEQ = [83, 1024, 2047, 32760] + FRAME = [4, 8, 16] + HIDDEN_SIZES = [83, 1024, 1338, 1536, 3072, 4096] + USE_AFFINE = [False, True] + USE_BIAS = [False, True] + NORM_TYPE = ["rms", "layer"] + GATE_SHAPE = ["1", "1D", "B1D", "BF1D"] + SCALE_SHIFT_SHAPE = ["1.2", "[1]", "BD", + "1D", "BSD", "B1D", "1SD", "11D", "BF1D"] + SEEDS = [0] + + args = [ + (DTYPES, "dtype"), + (PARAM_DTYPES, "param_type"), + (BATCH_NUM, "batch"), + (SEQ, "seq"), + (FRAME, "frame"), + (HIDDEN_SIZES, "hidden_size"), + (USE_AFFINE, "use_affine"), + (USE_BIAS, "use_bias"), + (NORM_TYPE, "norm_type"), + (GATE_SHAPE, "gate_shape"), + (SCALE_SHIFT_SHAPE, "scale_shift_shape"), + (SEEDS, "seed"), + ] + + def gen_cases(self, num): + num_of_args = len(self.args) + + while num > 0: + yield { + self.args[i][1]: random.choice(self.args[i][0]) + for i in range(num_of_args) + } + num -= 1 + + @classmethod + def setUpClass(cls): + if not torch.cuda.is_available(): + raise unittest.SkipTest("CUDA is not available") + torch.set_default_device("cuda") + + def _run_fused_test( + self, **params, + ): + dtype = params["dtype"] + param_type = params["param_type"] + batch = params["batch"] + seq = params["seq"] + frame = params["frame"] + hidden_size = params["hidden_size"] + use_affine = params["use_affine"] + use_bias = params["use_bias"] + norm_type = params["norm_type"] + gate_shape = params["gate_shape"] + scale_shift_shape = params["scale_shift_shape"] + seed = params["seed"] + + torch.manual_seed(seed) + layer = ScaleResidualNormScaleShift( + hidden_size, elementwise_affine=use_affine, bias=use_bias, dtype=param_type, norm_type=norm_type + ) + if use_affine: + w = torch.empty_like(layer.norm.weight) + w.normal_(mean=1.0, std=0.1) + layer.norm.weight.data.copy_(w) + if norm_type == "layer" and use_bias: + b = torch.empty_like(layer.norm.bias) + b.normal_(mean=1.0, std=0.1) + layer.norm.bias.data.copy_(b) + + residual = torch.randn(batch, seq, hidden_size, dtype=dtype) + x = torch.randn(batch, seq, hidden_size, dtype=dtype) + + if gate_shape == "1": + gate = 1 + elif gate_shape == "1D": + gate = torch.randn(1, hidden_size, dtype=dtype) + elif gate_shape == "BF1D": + if seq % frame != 0: + return + gate = torch.randn(batch, frame, 1, hidden_size, dtype=dtype) + elif gate_shape == "B1D": + gate = torch.randn(batch, 1, hidden_size, dtype=dtype) + else: + raise ValueError("Unknown gate shape.") + + if is_float(scale_shift_shape): + scale = torch.tensor(1.0, dtype=dtype) * float(scale_shift_shape) + shift = torch.tensor(1.0, dtype=dtype) * float(scale_shift_shape) + elif scale_shift_shape == "[1]": + scale = torch.ones(1, dtype=dtype) * torch.rand(1).item() + shift = torch.ones(1, dtype=dtype) * torch.rand(1).item() + elif scale_shift_shape == "BD": + scale = torch.randn(batch, hidden_size, dtype=dtype) + shift = torch.randn(batch, hidden_size, dtype=dtype) + elif scale_shift_shape == "1D": + scale = torch.randn(1, hidden_size, dtype=dtype) + shift = torch.randn(1, hidden_size, dtype=dtype) + elif scale_shift_shape == "BSD": + scale = torch.randn(batch, seq, hidden_size, dtype=dtype) + shift = torch.randn(batch, seq, hidden_size, dtype=dtype) + elif scale_shift_shape == "B1D": + scale = torch.randn(batch, 1, hidden_size, dtype=dtype) + shift = torch.randn(batch, 1, hidden_size, dtype=dtype) + elif scale_shift_shape == "1SD": + scale = torch.randn(1, seq, hidden_size, dtype=dtype) + shift = torch.randn(1, seq, hidden_size, dtype=dtype) + elif scale_shift_shape == "11D": + scale = torch.randn(1, 1, hidden_size, dtype=dtype) + shift = torch.randn(1, 1, hidden_size, dtype=dtype) + elif scale_shift_shape == "BF1D": + if seq % frame != 0: + return + scale = torch.randn(batch, frame, 1, hidden_size, dtype=dtype) + shift = torch.randn(batch, frame, 1, hidden_size, dtype=dtype) + + with torch.inference_mode(): + ref_out_mod, ref_out_resi = layer.forward_native( + residual, x, gate, shift, scale) + out_mod, out_resi = layer(residual, x, gate, shift, scale) + + if dtype == torch.float32 and param_type == torch.float32: + self.assertTrue(*allclose_with_tolerance( + out_mod, ref_out_mod, atol=1e-6, rtol=1e-4)) + self.assertTrue(*allclose_with_tolerance( + out_resi, ref_out_resi, atol=1e-6, rtol=1e-4)) + else: + self.assertTrue(*allclose_with_tolerance( + out_mod, ref_out_mod, atol=5e-2, rtol=1e-2)) + self.assertTrue(*allclose_with_tolerance( + out_resi, ref_out_resi, atol=5e-2, rtol=1e-2)) + + def test_fused(self): + for params in self.gen_cases(num=300): + with self.subTest(**params): + self._run_fused_test(**params) + torch.cuda.synchronize() + + +################################################################################ +# Performance Test +################################################################################ +class TestScaleResidualNormScaleShiftPerf(CustomTestCase): + DTYPES = [torch.float32, torch.float16] + PARAM_DTYPES = [torch.float32] + BATCH_NUM = [1, 2] + SEQ = [1024, 2048, 4096, 16380, 32760, 115200] + FRAME = [4] + HIDDEN_SIZES = [512, 1024, 1536, 2048, 3072] + USE_AFFINE = [True, False] + USE_BIAS = [True] + NORM_TYPE = ["layer"] + GATE_SHAPE = ["B1D"] + SCALE_SHIFT_SHAPE = ["11D"] + SEEDS = [0] + + args = [ + (DTYPES, "dtype"), + (PARAM_DTYPES, "param_type"), + (BATCH_NUM, "batch"), + (SEQ, "seq"), + (FRAME, "frame"), + (HIDDEN_SIZES, "hidden_size"), + (USE_AFFINE, "use_affine"), + (USE_BIAS, "use_bias"), + (NORM_TYPE, "norm_type"), + (GATE_SHAPE, "gate_shape"), + (SCALE_SHIFT_SHAPE, "scale_shift_shape"), + (SEEDS, "seed"), + ] + + def gen_cases(self): + keys = [arg[1] for arg in self.args] + value_lists = [arg[0] for arg in self.args] + + for values in itertools.product(*value_lists): + yield dict(zip(keys, values)) + + @classmethod + def setUpClass(cls): + if not torch.cuda.is_available(): + raise unittest.SkipTest("CUDA is not available") + torch.set_default_device(f"cuda") + + def _run_fused_test( + self, **params, + ): + dtype = params["dtype"] + param_type = params["param_type"] + batch = params["batch"] + seq = params["seq"] + frame = params["frame"] + hidden_size = params["hidden_size"] + use_affine = params["use_affine"] + use_bias = params["use_bias"] + norm_type = params["norm_type"] + gate_shape = params["gate_shape"] + scale_shift_shape = params["scale_shift_shape"] + seed = params["seed"] + + torch.manual_seed(seed) + + if norm_type == "layer" and use_bias == False: + return + + ref_layer = ScaleResidualLayerNormScaleShift( + hidden_size, elementwise_affine=use_affine, dtype=param_type, norm_type=norm_type + ) + layer = ScaleResidualNormScaleShift( + hidden_size, elementwise_affine=use_affine, bias=use_bias, dtype=param_type, norm_type=norm_type + ) + if use_affine: + w = torch.empty_like(layer.norm.weight) + w.normal_(mean=1.0, std=0.1) + layer.norm.weight.data.copy_(w) + ref_layer.norm.weight.data.copy_(w) + if norm_type == "layer" and use_bias: + b = torch.empty_like(layer.norm.bias) + b.normal_(mean=1.0, std=0.1) + layer.norm.bias.data.copy_(b) + ref_layer.norm.bias.data.copy_(b) + + residual = torch.randn(batch, seq, hidden_size, dtype=dtype) + x = torch.randn(batch, seq, hidden_size, dtype=dtype) + + if gate_shape == "1": + gate = 1 + elif gate_shape == "BF1D": + if seq % frame != 0: + return + gate = torch.randn(batch, frame, 1, hidden_size, dtype=dtype) + elif gate_shape == "B1D": + gate = torch.randn(batch, 1, hidden_size, dtype=dtype) + else: + raise ValueError("Unknown gate shape.") + + if is_float(scale_shift_shape): + scale = torch.ones(1, dtype=dtype) * float(scale_shift_shape) + shift = torch.ones(1, dtype=dtype) * float(scale_shift_shape) + elif scale_shift_shape == "[1]": + scale = torch.ones(1, dtype=dtype) * torch.rand(1).item() + shift = torch.ones(1, dtype=dtype) * torch.rand(1).item() + elif scale_shift_shape == "BD": + scale = torch.randn(batch, hidden_size, dtype=dtype) + shift = torch.randn(batch, hidden_size, dtype=dtype) + elif scale_shift_shape == "1D": + scale = torch.randn(1, hidden_size, dtype=dtype) + shift = torch.randn(1, hidden_size, dtype=dtype) + elif scale_shift_shape == "BSD": + scale = torch.randn(batch, seq, hidden_size, dtype=dtype) + shift = torch.randn(batch, seq, hidden_size, dtype=dtype) + elif scale_shift_shape == "B1D": + scale = torch.randn(batch, 1, hidden_size, dtype=dtype) + shift = torch.randn(batch, 1, hidden_size, dtype=dtype) + elif scale_shift_shape == "1SD": + scale = torch.randn(1, seq, hidden_size, dtype=dtype) + shift = torch.randn(1, seq, hidden_size, dtype=dtype) + elif scale_shift_shape == "11D": + scale = torch.randn(1, 1, hidden_size, dtype=dtype) + shift = torch.randn(1, 1, hidden_size, dtype=dtype) + elif scale_shift_shape == "BF1D": + if seq % frame != 0: + return + scale = torch.randn(batch, frame, 1, hidden_size, dtype=dtype) + shift = torch.randn(batch, frame, 1, hidden_size, dtype=dtype) + + with torch.inference_mode(): + ref_out_mod, ref_out_resi = ref_layer( + residual, x, gate, shift, scale) + out_mod, out_resi = layer(residual, x, gate, shift, scale) + + if dtype == torch.float32 and param_type == torch.float32: + self.assertTrue(*allclose_with_tolerance( + out_mod, ref_out_mod, atol=1e-6, rtol=1e-4)) + self.assertTrue(*allclose_with_tolerance( + out_resi, ref_out_resi, atol=1e-6, rtol=1e-4)) + else: + self.assertTrue(*allclose_with_tolerance( + out_mod, ref_out_mod, atol=5e-2, rtol=1e-2)) + self.assertTrue(*allclose_with_tolerance( + out_resi, ref_out_resi, atol=5e-2, rtol=1e-2)) + + # Perf + fused_kernel_time = benchmark(layer, residual, x, gate, shift, scale) + naive_kernel_time = benchmark( + ref_layer, residual, x, gate, shift, scale) + speedup = naive_kernel_time / fused_kernel_time + print( + f"[speedup]={speedup:.2f}x ({naive_kernel_time:.3f}ms/{fused_kernel_time:.3f}ms), " + f"dtype={dtype}, param_type={param_type}, batch={batch}, seq={seq}, " + f"frame={frame}, hidden={hidden_size}, use_affine={use_affine}, " + f"use_bias={use_bias}, norm_type={norm_type}, gate_shape={gate_shape}, " + f"scale_shift_shape={scale_shift_shape}, seed={seed}" + ) + return speedup + + def test_fused(self): + speedup = [] + for params in self.gen_cases(): + with self.subTest(**params): + torch.cuda.synchronize() + speedup.append(self._run_fused_test(**params)) + torch.cuda.synchronize() + avg_speedup = sum(speedup) / len(speedup) + print(f"Average Speedup = {avg_speedup}") + + +if __name__ == "__main__": + unittest.main(verbosity=2) From 7717b95cc438643103ffe783d37e7885c81032b8 Mon Sep 17 00:00:00 2001 From: Yihan Chen Date: Thu, 4 Dec 2025 16:17:55 +0000 Subject: [PATCH 2/8] remote origin impl --- .../runtime/layers/layernorm.py | 96 -------- .../test_scale_residual_norm_scale_shift.py | 216 +----------------- 2 files changed, 10 insertions(+), 302 deletions(-) diff --git a/python/sglang/multimodal_gen/runtime/layers/layernorm.py b/python/sglang/multimodal_gen/runtime/layers/layernorm.py index 855231f0164b..63927db03230 100644 --- a/python/sglang/multimodal_gen/runtime/layers/layernorm.py +++ b/python/sglang/multimodal_gen/runtime/layers/layernorm.py @@ -385,102 +385,6 @@ def forward_native( return normalized, residual_out -class ScaleResidualLayerNormScaleShift(nn.Module): - """ - Fused operation that combines: - 1. Gated residual connection - 2. LayerNorm - 3. Scale and shift operations - - This reduces memory bandwidth by combining memory-bound operations. - """ - - def __init__( - self, - hidden_size: int, - norm_type: str = "rms", - eps: float = 1e-6, - elementwise_affine: bool = False, - dtype: torch.dtype = torch.float32, - compute_dtype: torch.dtype | None = None, - prefix: str = "", - ): - super().__init__() - if norm_type == "rms": - self.norm = RMSNorm( - hidden_size, has_weight=elementwise_affine, eps=eps, dtype=dtype - ) - elif norm_type == "layer": - if compute_dtype == torch.float32: - self.norm = FP32LayerNorm( - hidden_size, elementwise_affine=elementwise_affine, eps=eps - ) - else: - self.norm = LayerNorm( - hidden_size, - elementwise_affine=elementwise_affine, - eps=eps, - dtype=dtype, - ) - else: - raise NotImplementedError(f"Norm type {norm_type} not implemented") - - def forward( - self, - residual: torch.Tensor, - x: torch.Tensor, - gate: torch.Tensor | int, - shift: torch.Tensor, - scale: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Apply gated residual connection, followed by layernorm and - scale/shift in a single fused operation. - - Returns: - Tuple containing: - - normalized and modulated output of shape: [batch_size, seq_len, inner_dim] - - residual value (value after residual connection - but before normalization) - """ - # x.shape: [batch_size, seq_len, inner_dim] - # Apply residual connection with gating - if isinstance(gate, int): - # used by cross-attention, should be 1 - assert gate == 1 - residual_output = residual + x - elif isinstance(gate, torch.Tensor): - if gate.dim() == 4: - # gate.shape: [batch_size, num_frames, 1, inner_dim] - num_frames = gate.shape[1] - frame_seqlen = x.shape[1] // num_frames - residual_output = residual + ( - x.unflatten(dim=1, sizes=(num_frames, frame_seqlen)) * gate - ).flatten(1, 2) - else: - # used by bidirectional self attention - # gate.shape: [batch_size, 1, inner_dim] - residual_output = residual + x * gate - else: - raise ValueError(f"Gate type {type(gate)} not supported") - # residual_output.shape: [batch_size, seq_len, inner_dim] - - # Apply normalization - normalized = self.norm(residual_output) - - # modulated = fused_scale_shift( - # normalized, - # scale, - # shift, - # ) - modulated = fuse_scale_shift_kernel( - normalized, - scale, - shift, - ) - return modulated, residual_output - - class LayerNormScaleShift(nn.Module): """ Fused operation that combines LayerNorm with scale and shift operations. diff --git a/python/sglang/multimodal_gen/test/test_scale_residual_norm_scale_shift.py b/python/sglang/multimodal_gen/test/test_scale_residual_norm_scale_shift.py index d4fe98b7465a..7484c7706efd 100644 --- a/python/sglang/multimodal_gen/test/test_scale_residual_norm_scale_shift.py +++ b/python/sglang/multimodal_gen/test/test_scale_residual_norm_scale_shift.py @@ -1,25 +1,12 @@ -from sglang.multimodal_gen.runtime.layers.triton_ops import fuse_scale_shift_kernel -from sglang.multimodal_gen.test.utils import allclose_with_tolerance from sglang.test.test_utils import CustomTestCase from sglang.multimodal_gen.runtime.layers.layernorm import ScaleResidualNormScaleShift -from sglang.multimodal_gen.runtime.layers.layernorm import ScaleResidualLayerNormScaleShift, LayerNorm -import torch.nn as nn import random import unittest -import itertools - import torch -import time -import os -print("CUDA_VISIBLE_DEVICES =", os.environ.get("CUDA_VISIBLE_DEVICES")) -print("torch sees", torch.cuda.device_count(), "GPUs") -print("Using device:", torch.cuda.get_device_name(0)) +from torch import Tensor -################################################################################ -# Benchmark -################################################################################ def is_float(value): try: float(value) @@ -28,36 +15,21 @@ def is_float(value): return False -def benchmark(fn, *args, **kwargs): - warmup = 2 - iters = 20 - # Make sure everything is clean - with torch.inference_mode(): - if torch.cuda.is_available(): - torch.cuda.synchronize() - - # --- Warmup --- - for _ in range(warmup): - fn(*args, **kwargs) - if torch.cuda.is_available(): - torch.cuda.synchronize() - - # --- Timed runs --- - t0 = time.time() - for _ in range(iters): - fn(*args, **kwargs) - if torch.cuda.is_available(): - torch.cuda.synchronize() - t1 = time.time() +def allclose_with_tolerance(x: Tensor, y: Tensor, atol: float, rtol: float, max_ratio=0.005) -> Tuple[bool, str]: + diff = torch.abs(x - y) + th = atol + rtol * torch.abs(y) - # Average time per iteration (ms) - avg_time_ms = (t1 - t0) * 1000 / iters - return avg_time_ms + # out-of-tolerance mask + bad_mask = diff > th + bad_ratio = bad_mask.float().mean().item() + return bad_ratio <= max_ratio, f"{bad_ratio:.6f} > {max_ratio}" ################################################################################ # Accuracy Test ################################################################################ + + class TestScaleResidualNormScaleShiftAccuracy(CustomTestCase): DTYPES = [torch.float32, torch.float16, torch.bfloat16] PARAM_DTYPES = [torch.float32, torch.float16, torch.bfloat16] @@ -202,173 +174,5 @@ def test_fused(self): torch.cuda.synchronize() -################################################################################ -# Performance Test -################################################################################ -class TestScaleResidualNormScaleShiftPerf(CustomTestCase): - DTYPES = [torch.float32, torch.float16] - PARAM_DTYPES = [torch.float32] - BATCH_NUM = [1, 2] - SEQ = [1024, 2048, 4096, 16380, 32760, 115200] - FRAME = [4] - HIDDEN_SIZES = [512, 1024, 1536, 2048, 3072] - USE_AFFINE = [True, False] - USE_BIAS = [True] - NORM_TYPE = ["layer"] - GATE_SHAPE = ["B1D"] - SCALE_SHIFT_SHAPE = ["11D"] - SEEDS = [0] - - args = [ - (DTYPES, "dtype"), - (PARAM_DTYPES, "param_type"), - (BATCH_NUM, "batch"), - (SEQ, "seq"), - (FRAME, "frame"), - (HIDDEN_SIZES, "hidden_size"), - (USE_AFFINE, "use_affine"), - (USE_BIAS, "use_bias"), - (NORM_TYPE, "norm_type"), - (GATE_SHAPE, "gate_shape"), - (SCALE_SHIFT_SHAPE, "scale_shift_shape"), - (SEEDS, "seed"), - ] - - def gen_cases(self): - keys = [arg[1] for arg in self.args] - value_lists = [arg[0] for arg in self.args] - - for values in itertools.product(*value_lists): - yield dict(zip(keys, values)) - - @classmethod - def setUpClass(cls): - if not torch.cuda.is_available(): - raise unittest.SkipTest("CUDA is not available") - torch.set_default_device(f"cuda") - - def _run_fused_test( - self, **params, - ): - dtype = params["dtype"] - param_type = params["param_type"] - batch = params["batch"] - seq = params["seq"] - frame = params["frame"] - hidden_size = params["hidden_size"] - use_affine = params["use_affine"] - use_bias = params["use_bias"] - norm_type = params["norm_type"] - gate_shape = params["gate_shape"] - scale_shift_shape = params["scale_shift_shape"] - seed = params["seed"] - - torch.manual_seed(seed) - - if norm_type == "layer" and use_bias == False: - return - - ref_layer = ScaleResidualLayerNormScaleShift( - hidden_size, elementwise_affine=use_affine, dtype=param_type, norm_type=norm_type - ) - layer = ScaleResidualNormScaleShift( - hidden_size, elementwise_affine=use_affine, bias=use_bias, dtype=param_type, norm_type=norm_type - ) - if use_affine: - w = torch.empty_like(layer.norm.weight) - w.normal_(mean=1.0, std=0.1) - layer.norm.weight.data.copy_(w) - ref_layer.norm.weight.data.copy_(w) - if norm_type == "layer" and use_bias: - b = torch.empty_like(layer.norm.bias) - b.normal_(mean=1.0, std=0.1) - layer.norm.bias.data.copy_(b) - ref_layer.norm.bias.data.copy_(b) - - residual = torch.randn(batch, seq, hidden_size, dtype=dtype) - x = torch.randn(batch, seq, hidden_size, dtype=dtype) - - if gate_shape == "1": - gate = 1 - elif gate_shape == "BF1D": - if seq % frame != 0: - return - gate = torch.randn(batch, frame, 1, hidden_size, dtype=dtype) - elif gate_shape == "B1D": - gate = torch.randn(batch, 1, hidden_size, dtype=dtype) - else: - raise ValueError("Unknown gate shape.") - - if is_float(scale_shift_shape): - scale = torch.ones(1, dtype=dtype) * float(scale_shift_shape) - shift = torch.ones(1, dtype=dtype) * float(scale_shift_shape) - elif scale_shift_shape == "[1]": - scale = torch.ones(1, dtype=dtype) * torch.rand(1).item() - shift = torch.ones(1, dtype=dtype) * torch.rand(1).item() - elif scale_shift_shape == "BD": - scale = torch.randn(batch, hidden_size, dtype=dtype) - shift = torch.randn(batch, hidden_size, dtype=dtype) - elif scale_shift_shape == "1D": - scale = torch.randn(1, hidden_size, dtype=dtype) - shift = torch.randn(1, hidden_size, dtype=dtype) - elif scale_shift_shape == "BSD": - scale = torch.randn(batch, seq, hidden_size, dtype=dtype) - shift = torch.randn(batch, seq, hidden_size, dtype=dtype) - elif scale_shift_shape == "B1D": - scale = torch.randn(batch, 1, hidden_size, dtype=dtype) - shift = torch.randn(batch, 1, hidden_size, dtype=dtype) - elif scale_shift_shape == "1SD": - scale = torch.randn(1, seq, hidden_size, dtype=dtype) - shift = torch.randn(1, seq, hidden_size, dtype=dtype) - elif scale_shift_shape == "11D": - scale = torch.randn(1, 1, hidden_size, dtype=dtype) - shift = torch.randn(1, 1, hidden_size, dtype=dtype) - elif scale_shift_shape == "BF1D": - if seq % frame != 0: - return - scale = torch.randn(batch, frame, 1, hidden_size, dtype=dtype) - shift = torch.randn(batch, frame, 1, hidden_size, dtype=dtype) - - with torch.inference_mode(): - ref_out_mod, ref_out_resi = ref_layer( - residual, x, gate, shift, scale) - out_mod, out_resi = layer(residual, x, gate, shift, scale) - - if dtype == torch.float32 and param_type == torch.float32: - self.assertTrue(*allclose_with_tolerance( - out_mod, ref_out_mod, atol=1e-6, rtol=1e-4)) - self.assertTrue(*allclose_with_tolerance( - out_resi, ref_out_resi, atol=1e-6, rtol=1e-4)) - else: - self.assertTrue(*allclose_with_tolerance( - out_mod, ref_out_mod, atol=5e-2, rtol=1e-2)) - self.assertTrue(*allclose_with_tolerance( - out_resi, ref_out_resi, atol=5e-2, rtol=1e-2)) - - # Perf - fused_kernel_time = benchmark(layer, residual, x, gate, shift, scale) - naive_kernel_time = benchmark( - ref_layer, residual, x, gate, shift, scale) - speedup = naive_kernel_time / fused_kernel_time - print( - f"[speedup]={speedup:.2f}x ({naive_kernel_time:.3f}ms/{fused_kernel_time:.3f}ms), " - f"dtype={dtype}, param_type={param_type}, batch={batch}, seq={seq}, " - f"frame={frame}, hidden={hidden_size}, use_affine={use_affine}, " - f"use_bias={use_bias}, norm_type={norm_type}, gate_shape={gate_shape}, " - f"scale_shift_shape={scale_shift_shape}, seed={seed}" - ) - return speedup - - def test_fused(self): - speedup = [] - for params in self.gen_cases(): - with self.subTest(**params): - torch.cuda.synchronize() - speedup.append(self._run_fused_test(**params)) - torch.cuda.synchronize() - avg_speedup = sum(speedup) / len(speedup) - print(f"Average Speedup = {avg_speedup}") - - if __name__ == "__main__": unittest.main(verbosity=2) From 01f06e426485fc107e5b010544e9c6f30463243d Mon Sep 17 00:00:00 2001 From: Yihan Chen Date: Thu, 4 Dec 2025 16:22:27 +0000 Subject: [PATCH 3/8] wan model --- .../runtime/models/dits/wanvideo.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py b/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py index cb674e49195b..f0f7b44928c7 100644 --- a/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py @@ -21,7 +21,7 @@ LayerNormScaleShift, RMSNorm, ScaleResidual, - ScaleResidualLayerNormScaleShift, + ScaleResidualNormScaleShift, ) from sglang.multimodal_gen.runtime.layers.linear import ReplicatedLinear from sglang.multimodal_gen.runtime.layers.mlp import MLP @@ -294,13 +294,13 @@ def __init__( logger.error("QK Norm type not supported") raise Exception assert cross_attn_norm is True - self.self_attn_residual_norm = ScaleResidualLayerNormScaleShift( + self.self_attn_residual_norm = ScaleResidualNormScaleShift( dim, norm_type="layer", eps=eps, elementwise_affine=True, + bias=True, dtype=torch.float32, - compute_dtype=torch.float32, ) # 2. Cross-attention @@ -322,13 +322,13 @@ def __init__( eps=eps, supported_attention_backends=supported_attention_backends, ) - self.cross_attn_residual_norm = ScaleResidualLayerNormScaleShift( + self.cross_attn_residual_norm = ScaleResidualNormScaleShift( dim, norm_type="layer", eps=eps, elementwise_affine=False, + bias=False, dtype=torch.float32, - compute_dtype=torch.float32, ) # 3. Feed-forward @@ -469,13 +469,13 @@ def __init__( logger.error("QK Norm type not supported") raise Exception assert cross_attn_norm is True - self.self_attn_residual_norm = ScaleResidualLayerNormScaleShift( + self.cross_attn_residual_norm = ScaleResidualNormScaleShift( dim, norm_type="layer", eps=eps, - elementwise_affine=True, + elementwise_affine=False, + bias=False, dtype=torch.float32, - compute_dtype=torch.float32, ) if AttentionBackendEnum.VIDEO_SPARSE_ATTN in supported_attention_backends: @@ -499,13 +499,13 @@ def __init__( eps=eps, supported_attention_backends=supported_attention_backends, ) - self.cross_attn_residual_norm = ScaleResidualLayerNormScaleShift( + self.cross_attn_residual_norm = ScaleResidualNormScaleShift( dim, norm_type="layer", eps=eps, elementwise_affine=False, + bias=False, dtype=torch.float32, - compute_dtype=torch.float32, ) # 3. Feed-forward From 9dee21421a44095a4bf4f1ca92566601f6b9d313 Mon Sep 17 00:00:00 2001 From: Yihan Chen Date: Thu, 4 Dec 2025 17:04:43 +0000 Subject: [PATCH 4/8] bug fix for wanvideo 3d --- python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py b/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py index f0f7b44928c7..8f36b052018c 100644 --- a/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py @@ -469,7 +469,7 @@ def __init__( logger.error("QK Norm type not supported") raise Exception assert cross_attn_norm is True - self.cross_attn_residual_norm = ScaleResidualNormScaleShift( + self.self_attn_residual_norm = ScaleResidualNormScaleShift( dim, norm_type="layer", eps=eps, From 038ce1de5518def46e5a67aa9a5f7e0a9eb14ed3 Mon Sep 17 00:00:00 2001 From: Yihan Chen Date: Thu, 4 Dec 2025 17:08:22 +0000 Subject: [PATCH 5/8] fix --- .../multimodal_gen/test/test_scale_residual_norm_scale_shift.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/multimodal_gen/test/test_scale_residual_norm_scale_shift.py b/python/sglang/multimodal_gen/test/test_scale_residual_norm_scale_shift.py index 7484c7706efd..86155bce2579 100644 --- a/python/sglang/multimodal_gen/test/test_scale_residual_norm_scale_shift.py +++ b/python/sglang/multimodal_gen/test/test_scale_residual_norm_scale_shift.py @@ -5,6 +5,7 @@ import unittest import torch from torch import Tensor +from typing import Tuple def is_float(value): From 96426899f67a75df295aeb41d853cd5cd362fb24 Mon Sep 17 00:00:00 2001 From: Yihan Chen Date: Thu, 4 Dec 2025 17:11:57 +0000 Subject: [PATCH 6/8] fix --- python/sglang/multimodal_gen/runtime/layers/layernorm.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/sglang/multimodal_gen/runtime/layers/layernorm.py b/python/sglang/multimodal_gen/runtime/layers/layernorm.py index 63927db03230..982bdcd2e1df 100644 --- a/python/sglang/multimodal_gen/runtime/layers/layernorm.py +++ b/python/sglang/multimodal_gen/runtime/layers/layernorm.py @@ -31,6 +31,7 @@ _is_xpu = is_xpu() from sgl_kernel import fused_add_rmsnorm, rmsnorm +from sgl_kernel import scale_residual_norm_scale_shit # Copied and adapted from sglang @@ -311,12 +312,12 @@ def forward_cuda( shift: torch.Tensor, scale: torch.Tensor, ): - return torch.ops.sglang_ops.scale_residual_norm_scale_shift( + scale_residual_norm_scale_shit( residual, x, gate if isinstance(gate, torch.Tensor) else None, self.norm.weight, self.norm.bias, scale, shift, - self.eps, self.norm_type == "rms", + self.eps, self.norm_type, ) def forward_native( @@ -358,7 +359,7 @@ def forward_native( # 4. apply scale/shift if given batch, seq_len, hidden_dim = x.shape if scale.ndim <= 3: - if scale.ndim == 0 or (scale.ndim == 1 and scale.size == 1): + if scale.ndim == 0 or (scale.ndim == 1 and scale.numel() == 1): # (), (1) → (B, S, D) scale = scale.expand(batch, seq_len, hidden_dim) shift = shift.expand(batch, seq_len, hidden_dim) From 46e16fd5d76da0fdb2e4dd6e06044891b7a7b01b Mon Sep 17 00:00:00 2001 From: Yihan Chen Date: Thu, 4 Dec 2025 17:18:49 +0000 Subject: [PATCH 7/8] code-format --- .../runtime/layers/layernorm.py | 34 +++++++------ .../test_scale_residual_norm_scale_shift.py | 49 ++++++++++++------- 2 files changed, 50 insertions(+), 33 deletions(-) diff --git a/python/sglang/multimodal_gen/runtime/layers/layernorm.py b/python/sglang/multimodal_gen/runtime/layers/layernorm.py index 982bdcd2e1df..f3170d4147b0 100644 --- a/python/sglang/multimodal_gen/runtime/layers/layernorm.py +++ b/python/sglang/multimodal_gen/runtime/layers/layernorm.py @@ -10,11 +10,7 @@ import torch.nn.functional as F from sglang.multimodal_gen.runtime.layers.custom_op import CustomOp -from sglang.multimodal_gen.runtime.layers.triton_ops import ( - fuse_scale_shift_kernel, - norm_infer, - rms_norm_fn, -) +from sglang.multimodal_gen.runtime.layers.triton_ops import norm_infer, rms_norm_fn from sglang.multimodal_gen.runtime.utils.common import ( get_bool_env_var, is_cpu, @@ -30,8 +26,7 @@ _is_cpu = is_cpu() _is_xpu = is_xpu() -from sgl_kernel import fused_add_rmsnorm, rmsnorm -from sgl_kernel import scale_residual_norm_scale_shit +from sgl_kernel import fused_add_rmsnorm, rmsnorm, scale_residual_norm_scale_shit # Copied and adapted from sglang @@ -313,11 +308,15 @@ def forward_cuda( scale: torch.Tensor, ): scale_residual_norm_scale_shit( - residual, x, + residual, + x, gate if isinstance(gate, torch.Tensor) else None, - self.norm.weight, self.norm.bias, - scale, shift, - self.eps, self.norm_type, + self.norm.weight, + self.norm.bias, + scale, + shift, + self.eps, + self.norm_type, ) def forward_native( @@ -343,11 +342,11 @@ def forward_native( else: residual_out = residual + x * gate # 2. normalize - if self.norm_type == "layer": # LayerNorm + if self.norm_type == "layer": # LayerNorm mean = residual_out.mean(dim=-1, keepdim=True) var = residual_out.var(dim=-1, unbiased=False, keepdim=True) normalized = (residual_out - mean) / torch.sqrt(var + self.eps) - elif self.norm_type == "rms": # RMSNorm + elif self.norm_type == "rms": # RMSNorm rms = residual_out.pow(2).mean(dim=-1, keepdim=True) normalized = residual_out / torch.sqrt(rms + self.eps) # 3. apply affine transform if given @@ -363,7 +362,10 @@ def forward_native( # (), (1) → (B, S, D) scale = scale.expand(batch, seq_len, hidden_dim) shift = shift.expand(batch, seq_len, hidden_dim) - elif scale.ndim == 2 and scale.shape in [(1, hidden_dim), (batch, hidden_dim)]: + elif scale.ndim == 2 and scale.shape in [ + (1, hidden_dim), + (batch, hidden_dim), + ]: # (B, D) or (1, D) → (B, S, 1, D) scale = scale[:, None, :].expand(batch, seq_len, hidden_dim) shift = shift[:, None, :].expand(batch, seq_len, hidden_dim) @@ -382,7 +384,9 @@ def forward_native( frame_seqlen = normalized.shape[1] // num_frames normalized = ( normalized.unflatten(dim=1, sizes=(num_frames, frame_seqlen)) - * (1.0 + scale) + shift).flatten(1, 2) + * (1.0 + scale) + + shift + ).flatten(1, 2) return normalized, residual_out diff --git a/python/sglang/multimodal_gen/test/test_scale_residual_norm_scale_shift.py b/python/sglang/multimodal_gen/test/test_scale_residual_norm_scale_shift.py index 86155bce2579..1be008a1cb2a 100644 --- a/python/sglang/multimodal_gen/test/test_scale_residual_norm_scale_shift.py +++ b/python/sglang/multimodal_gen/test/test_scale_residual_norm_scale_shift.py @@ -1,11 +1,12 @@ - -from sglang.test.test_utils import CustomTestCase -from sglang.multimodal_gen.runtime.layers.layernorm import ScaleResidualNormScaleShift import random import unittest +from typing import Tuple + import torch from torch import Tensor -from typing import Tuple + +from sglang.multimodal_gen.runtime.layers.layernorm import ScaleResidualNormScaleShift +from sglang.test.test_utils import CustomTestCase def is_float(value): @@ -16,7 +17,9 @@ def is_float(value): return False -def allclose_with_tolerance(x: Tensor, y: Tensor, atol: float, rtol: float, max_ratio=0.005) -> Tuple[bool, str]: +def allclose_with_tolerance( + x: Tensor, y: Tensor, atol: float, rtol: float, max_ratio=0.005 +) -> Tuple[bool, str]: diff = torch.abs(x - y) th = atol + rtol * torch.abs(y) @@ -26,6 +29,7 @@ def allclose_with_tolerance(x: Tensor, y: Tensor, atol: float, rtol: float, max_ return bad_ratio <= max_ratio, f"{bad_ratio:.6f} > {max_ratio}" + ################################################################################ # Accuracy Test ################################################################################ @@ -42,8 +46,7 @@ class TestScaleResidualNormScaleShiftAccuracy(CustomTestCase): USE_BIAS = [False, True] NORM_TYPE = ["rms", "layer"] GATE_SHAPE = ["1", "1D", "B1D", "BF1D"] - SCALE_SHIFT_SHAPE = ["1.2", "[1]", "BD", - "1D", "BSD", "B1D", "1SD", "11D", "BF1D"] + SCALE_SHIFT_SHAPE = ["1.2", "[1]", "BD", "1D", "BSD", "B1D", "1SD", "11D", "BF1D"] SEEDS = [0] args = [ @@ -78,7 +81,8 @@ def setUpClass(cls): torch.set_default_device("cuda") def _run_fused_test( - self, **params, + self, + **params, ): dtype = params["dtype"] param_type = params["param_type"] @@ -95,7 +99,11 @@ def _run_fused_test( torch.manual_seed(seed) layer = ScaleResidualNormScaleShift( - hidden_size, elementwise_affine=use_affine, bias=use_bias, dtype=param_type, norm_type=norm_type + hidden_size, + elementwise_affine=use_affine, + bias=use_bias, + dtype=param_type, + norm_type=norm_type, ) if use_affine: w = torch.empty_like(layer.norm.weight) @@ -154,19 +162,24 @@ def _run_fused_test( with torch.inference_mode(): ref_out_mod, ref_out_resi = layer.forward_native( - residual, x, gate, shift, scale) + residual, x, gate, shift, scale + ) out_mod, out_resi = layer(residual, x, gate, shift, scale) if dtype == torch.float32 and param_type == torch.float32: - self.assertTrue(*allclose_with_tolerance( - out_mod, ref_out_mod, atol=1e-6, rtol=1e-4)) - self.assertTrue(*allclose_with_tolerance( - out_resi, ref_out_resi, atol=1e-6, rtol=1e-4)) + self.assertTrue( + *allclose_with_tolerance(out_mod, ref_out_mod, atol=1e-6, rtol=1e-4) + ) + self.assertTrue( + *allclose_with_tolerance(out_resi, ref_out_resi, atol=1e-6, rtol=1e-4) + ) else: - self.assertTrue(*allclose_with_tolerance( - out_mod, ref_out_mod, atol=5e-2, rtol=1e-2)) - self.assertTrue(*allclose_with_tolerance( - out_resi, ref_out_resi, atol=5e-2, rtol=1e-2)) + self.assertTrue( + *allclose_with_tolerance(out_mod, ref_out_mod, atol=5e-2, rtol=1e-2) + ) + self.assertTrue( + *allclose_with_tolerance(out_resi, ref_out_resi, atol=5e-2, rtol=1e-2) + ) def test_fused(self): for params in self.gen_cases(num=300): From bed8c125931127ed875c071d8641631c730ac15d Mon Sep 17 00:00:00 2001 From: Yihan Chen Date: Thu, 4 Dec 2025 17:19:28 +0000 Subject: [PATCH 8/8] fix --- python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py b/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py index 8f36b052018c..5a219e1c3be0 100644 --- a/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py @@ -473,8 +473,8 @@ def __init__( dim, norm_type="layer", eps=eps, - elementwise_affine=False, - bias=False, + elementwise_affine=True, + bias=True, dtype=torch.float32, )