Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 97 additions & 72 deletions python/sglang/multimodal_gen/runtime/layers/layernorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,7 @@
import torch.nn.functional as F

from sglang.multimodal_gen.runtime.layers.custom_op import CustomOp
from sglang.multimodal_gen.runtime.layers.triton_ops import (
fuse_scale_shift_kernel,
norm_infer,
rms_norm_fn,
)
from sglang.multimodal_gen.runtime.layers.triton_ops import norm_infer, rms_norm_fn
from sglang.multimodal_gen.runtime.utils.common import (
get_bool_env_var,
is_cpu,
Expand All @@ -30,7 +26,7 @@
_is_cpu = is_cpu()
_is_xpu = is_xpu()

from sgl_kernel import fused_add_rmsnorm, rmsnorm
from sgl_kernel import fused_add_rmsnorm, rmsnorm, scale_residual_norm_scale_shit


# Copied and adapted from sglang
Expand Down Expand Up @@ -269,100 +265,129 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
).to(origin_dtype)


class ScaleResidualLayerNormScaleShift(nn.Module):
"""
Fused operation that combines:
1. Gated residual connection
2. LayerNorm
3. Scale and shift operations

This reduces memory bandwidth by combining memory-bound operations.
"""
@CustomOp.register("scale_residual_norm_scale_shift")
class ScaleResidualNormScaleShift(CustomOp):

def __init__(
self,
hidden_size: int,
eps=1e-6,
norm_type: str = "rms",
eps: float = 1e-6,
elementwise_affine: bool = False,
dtype: torch.dtype = torch.float32,
compute_dtype: torch.dtype | None = None,
prefix: str = "",
bias: bool = False,
device=None,
dtype=None,
):
super().__init__()
if norm_type == "rms":
self.norm = RMSNorm(
hidden_size, has_weight=elementwise_affine, eps=eps, dtype=dtype
self.hidden_size = hidden_size
self.eps = eps
self.norm_type = norm_type.lower()
self.norm = nn.Module()

factory_kwargs = {"device": device, "dtype": dtype}
if elementwise_affine:
self.norm.weight = torch.nn.Parameter(
torch.empty(hidden_size, **factory_kwargs)
)
elif norm_type == "layer":
if compute_dtype == torch.float32:
self.norm = FP32LayerNorm(
hidden_size, elementwise_affine=elementwise_affine, eps=eps
if self.norm_type == "layer" and bias:
self.norm.bias = torch.nn.Parameter(
torch.empty(hidden_size, **factory_kwargs)
)
else:
self.norm = LayerNorm(
hidden_size,
elementwise_affine=elementwise_affine,
eps=eps,
dtype=dtype,
)
self.norm.register_parameter("bias", None)
else:
raise NotImplementedError(f"Norm type {norm_type} not implemented")
self.norm.register_parameter("weight", None)
self.norm.register_parameter("bias", None)

def forward(
def forward_cuda(
self,
residual: torch.Tensor,
x: torch.Tensor,
gate: torch.Tensor | int,
shift: torch.Tensor,
scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Apply gated residual connection, followed by layernorm and
scale/shift in a single fused operation.

Returns:
Tuple containing:
- normalized and modulated output of shape: [batch_size, seq_len, inner_dim]
- residual value (value after residual connection
but before normalization)
"""
# x.shape: [batch_size, seq_len, inner_dim]
# Apply residual connection with gating
if isinstance(gate, int):
# used by cross-attention, should be 1
assert gate == 1
residual_output = residual + x
elif isinstance(gate, torch.Tensor):
):
scale_residual_norm_scale_shit(
residual,
x,
gate if isinstance(gate, torch.Tensor) else None,
self.norm.weight,
self.norm.bias,
scale,
shift,
self.eps,
self.norm_type,
)

def forward_native(
self,
residual: torch.Tensor,
x: torch.Tensor,
gate: torch.Tensor | int,
shift: torch.Tensor,
scale: torch.Tensor,
):
# 1. residual add
if isinstance(gate, torch.Tensor):
if gate.dim() == 4:
# gate.shape: [batch_size, num_frames, 1, inner_dim]
num_frames = gate.shape[1]
frame_seqlen = x.shape[1] // num_frames
residual_output = residual + (
residual_out = residual + (
x.unflatten(dim=1, sizes=(num_frames, frame_seqlen)) * gate
).flatten(1, 2)
else:
# used by bidirectional self attention
# gate.shape: [batch_size, 1, inner_dim]
residual_output = residual + x * gate
residual_out = residual + x * gate
else:
raise ValueError(f"Gate type {type(gate)} not supported")
# residual_output.shape: [batch_size, seq_len, inner_dim]

# Apply normalization
normalized = self.norm(residual_output)

# modulated = fused_scale_shift(
# normalized,
# scale,
# shift,
# )
modulated = fuse_scale_shift_kernel(
normalized,
scale,
shift,
)
return modulated, residual_output
residual_out = residual + x * gate
# 2. normalize
if self.norm_type == "layer": # LayerNorm
mean = residual_out.mean(dim=-1, keepdim=True)
var = residual_out.var(dim=-1, unbiased=False, keepdim=True)
normalized = (residual_out - mean) / torch.sqrt(var + self.eps)
elif self.norm_type == "rms": # RMSNorm
rms = residual_out.pow(2).mean(dim=-1, keepdim=True)
normalized = residual_out / torch.sqrt(rms + self.eps)
# 3. apply affine transform if given
norm_weight, norm_bias = self.norm.weight, self.norm.bias
if norm_weight is not None and norm_bias is not None:
normalized = normalized * norm_weight + norm_bias
elif norm_weight is not None:
normalized = normalized * norm_weight
# 4. apply scale/shift if given
batch, seq_len, hidden_dim = x.shape
if scale.ndim <= 3:
if scale.ndim == 0 or (scale.ndim == 1 and scale.numel() == 1):
# (), (1) → (B, S, D)
scale = scale.expand(batch, seq_len, hidden_dim)
shift = shift.expand(batch, seq_len, hidden_dim)
elif scale.ndim == 2 and scale.shape in [
(1, hidden_dim),
(batch, hidden_dim),
]:
# (B, D) or (1, D) → (B, S, 1, D)
scale = scale[:, None, :].expand(batch, seq_len, hidden_dim)
shift = shift[:, None, :].expand(batch, seq_len, hidden_dim)
elif scale.ndim == 3 and scale.shape in [
(batch, seq_len, hidden_dim),
(batch, 1, hidden_dim),
(1, seq_len, hidden_dim),
(1, 1, hidden_dim),
]:
# (B, S, D), (B, 1, D), (1, S, D), (1, 1, D) → (B, S, 1, D)
scale = scale.expand(batch, seq_len, hidden_dim)
shift = shift.expand(batch, seq_len, hidden_dim)
normalized = normalized * (1.0 + scale) + shift
elif scale.ndim == 4 and scale.shape == (batch, scale.shape[1], 1, hidden_dim):
num_frames = scale.shape[1]
frame_seqlen = normalized.shape[1] // num_frames
normalized = (
normalized.unflatten(dim=1, sizes=(num_frames, frame_seqlen))
* (1.0 + scale)
+ shift
).flatten(1, 2)
return normalized, residual_out


class LayerNormScaleShift(nn.Module):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
LayerNormScaleShift,
RMSNorm,
ScaleResidual,
ScaleResidualLayerNormScaleShift,
ScaleResidualNormScaleShift,
)
from sglang.multimodal_gen.runtime.layers.linear import ReplicatedLinear
from sglang.multimodal_gen.runtime.layers.mlp import MLP
Expand Down Expand Up @@ -292,25 +292,24 @@ def __init__(
print("QK Norm type not supported")
raise Exception
assert cross_attn_norm is True
self.self_attn_residual_norm = ScaleResidualLayerNormScaleShift(
self.self_attn_residual_norm = ScaleResidualNormScaleShift(
dim,
norm_type="layer",
eps=eps,
elementwise_affine=True,
bias=True,
dtype=torch.float32,
compute_dtype=torch.float32,
)

# 2. Cross-attention
# Only T2V for now
self.attn2 = WanT2VCrossAttention(dim, num_heads, qk_norm=qk_norm, eps=eps)
self.cross_attn_residual_norm = ScaleResidualLayerNormScaleShift(
self.cross_attn_residual_norm = ScaleResidualNormScaleShift(
dim,
norm_type="layer",
eps=eps,
elementwise_affine=False,
dtype=torch.float32,
compute_dtype=torch.float32,
)

# 3. Feed-forward
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
LayerNormScaleShift,
RMSNorm,
ScaleResidual,
ScaleResidualLayerNormScaleShift,
ScaleResidualNormScaleShift,
)
from sglang.multimodal_gen.runtime.layers.linear import ReplicatedLinear
from sglang.multimodal_gen.runtime.layers.mlp import MLP
Expand Down Expand Up @@ -74,7 +74,7 @@ def __init__(
self.img_attn_norm = LayerNormScaleShift(
hidden_size, norm_type="layer", elementwise_affine=False, dtype=dtype
)
self.img_attn_residual_mlp_norm = ScaleResidualLayerNormScaleShift(
self.img_attn_residual_mlp_norm = ScaleResidualNormScaleShift(
hidden_size, norm_type="layer", elementwise_affine=False, dtype=dtype
)
self.img_mlp_residual = ScaleResidual()
Expand Down Expand Up @@ -120,7 +120,7 @@ def __init__(
self.txt_attn_norm = LayerNormScaleShift(
hidden_size, norm_type="layer", elementwise_affine=False, dtype=dtype
)
self.txt_attn_residual_mlp_norm = ScaleResidualLayerNormScaleShift(
self.txt_attn_residual_mlp_norm = ScaleResidualNormScaleShift(
hidden_size, norm_type="layer", elementwise_affine=False, dtype=dtype
)
self.txt_mlp_residual = ScaleResidual()
Expand Down
18 changes: 9 additions & 9 deletions python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
LayerNormScaleShift,
RMSNorm,
ScaleResidual,
ScaleResidualLayerNormScaleShift,
ScaleResidualNormScaleShift,
)
from sglang.multimodal_gen.runtime.layers.linear import ReplicatedLinear
from sglang.multimodal_gen.runtime.layers.mlp import MLP
Expand Down Expand Up @@ -294,13 +294,13 @@ def __init__(
logger.error("QK Norm type not supported")
raise Exception
assert cross_attn_norm is True
self.self_attn_residual_norm = ScaleResidualLayerNormScaleShift(
self.self_attn_residual_norm = ScaleResidualNormScaleShift(
dim,
norm_type="layer",
eps=eps,
elementwise_affine=True,
bias=True,
dtype=torch.float32,
compute_dtype=torch.float32,
)

# 2. Cross-attention
Expand All @@ -322,13 +322,13 @@ def __init__(
eps=eps,
supported_attention_backends=supported_attention_backends,
)
self.cross_attn_residual_norm = ScaleResidualLayerNormScaleShift(
self.cross_attn_residual_norm = ScaleResidualNormScaleShift(
dim,
norm_type="layer",
eps=eps,
elementwise_affine=False,
bias=False,
dtype=torch.float32,
compute_dtype=torch.float32,
)

# 3. Feed-forward
Expand Down Expand Up @@ -469,13 +469,13 @@ def __init__(
logger.error("QK Norm type not supported")
raise Exception
assert cross_attn_norm is True
self.self_attn_residual_norm = ScaleResidualLayerNormScaleShift(
self.self_attn_residual_norm = ScaleResidualNormScaleShift(
dim,
norm_type="layer",
eps=eps,
elementwise_affine=True,
bias=True,
dtype=torch.float32,
compute_dtype=torch.float32,
)

if AttentionBackendEnum.VIDEO_SPARSE_ATTN in supported_attention_backends:
Expand All @@ -499,13 +499,13 @@ def __init__(
eps=eps,
supported_attention_backends=supported_attention_backends,
)
self.cross_attn_residual_norm = ScaleResidualLayerNormScaleShift(
self.cross_attn_residual_norm = ScaleResidualNormScaleShift(
dim,
norm_type="layer",
eps=eps,
elementwise_affine=False,
bias=False,
dtype=torch.float32,
compute_dtype=torch.float32,
)

# 3. Feed-forward
Expand Down
Loading
Loading