Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion python/sglang/multimodal_gen/runtime/layers/layernorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
fuse_scale_shift_kernel,
norm_infer,
rms_norm_fn,
triton_one_pass_rms_norm,
)
from sglang.multimodal_gen.runtime.utils.common import get_bool_env_var

Expand Down Expand Up @@ -76,7 +77,12 @@ def forward_cuda(
fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
return x.view(shape), residual.view(residual_shape)
else:
out = rmsnorm(x, self.weight.data, self.variance_epsilon)
if x.shape[-1] <= 128:
out = triton_one_pass_rms_norm(
x, self.weight.data, self.variance_epsilon
)
else:
out = rmsnorm(x, self.weight.data, self.variance_epsilon)
out = out.view(shape)
return out

Expand Down
55 changes: 55 additions & 0 deletions python/sglang/multimodal_gen/runtime/layers/triton_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1106,3 +1106,58 @@ def rms_norm_fn(
out,
residual_out,
)


# Adapted from https://github.com/ModelTC/LightX2V/blob/main/lightx2v/common/ops/norm/triton_ops.py#L905-L956
@triton.jit
def _rms_norm_tiled_onepass(
y_ptr,
x_ptr,
w_ptr,
SEQ: tl.constexpr,
DIM: tl.constexpr,
EPS: tl.constexpr,
BLOCK_SIZE_SEQ: tl.constexpr,
BLOCK_SIZE_DIM: tl.constexpr,
):
seq_blk_id = tl.program_id(0)
seq_id = seq_blk_id * BLOCK_SIZE_SEQ

seq_offset = seq_id + tl.arange(0, BLOCK_SIZE_SEQ)[:, None]
s_mask = seq_offset < SEQ
d_offset = tl.arange(0, BLOCK_SIZE_DIM)[None, :]
d_mask = d_offset < DIM
y_blk = y_ptr + seq_offset * DIM + d_offset
x_blk = x_ptr + seq_offset * DIM + d_offset
mask = s_mask & d_mask

x = tl.load(x_blk, mask=mask, other=0.0).to(tl.float32)
mean_square = tl.sum(x * x, axis=1, keep_dims=True) / DIM
rstd = tl.math.rsqrt(mean_square + EPS)
w = tl.load(w_ptr + d_offset, mask=d_mask)
tl.store(y_blk, x * rstd * w, mask=mask)


def triton_one_pass_rms_norm(x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6):
shape = x.shape
x = x.contiguous()
y = torch.empty_like(x)
x_view = x.reshape(-1, shape[-1])
y_view = y.reshape(-1, shape[-1])
S, D = x_view.shape

BLOCK_SIZE_SEQ = min(16, triton.next_power_of_2(max(1, S // 512)))
grid = (triton.cdiv(S, BLOCK_SIZE_SEQ),)

with torch.cuda.device(x.device):
torch.library.wrap_triton(_rms_norm_tiled_onepass)[grid](
y_view,
x_view,
w,
S,
D,
eps,
BLOCK_SIZE_DIM=triton.next_power_of_2(D),
BLOCK_SIZE_SEQ=BLOCK_SIZE_SEQ,
)
return y
42 changes: 37 additions & 5 deletions python/sglang/multimodal_gen/runtime/models/dits/flux_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from diffusers.models.normalization import AdaLayerNormContinuous

from sglang.jit_kernel.norm import can_use_fused_inplace_qknorm
from sglang.multimodal_gen.configs.models.dits.flux import FluxConfig
from sglang.multimodal_gen.runtime.layers.attention import USPAttention
from sglang.multimodal_gen.runtime.layers.layernorm import RMSNorm
from sglang.multimodal_gen.runtime.layers.layernorm import RMSNorm, apply_qk_norm
from sglang.multimodal_gen.runtime.layers.linear import ColumnParallelLinear
from sglang.multimodal_gen.runtime.layers.rotary_embedding import (
NDRotaryEmbedding,
Expand Down Expand Up @@ -196,16 +197,47 @@ def forward(
key = key.unflatten(-1, (self.heads, -1))
value = value.unflatten(-1, (self.heads, -1))

query = self.norm_q(query)
key = self.norm_k(key)
if (
query.is_cuda
and (self.norm_q.variance_epsilon == self.norm_k.variance_epsilon)
and can_use_fused_inplace_qknorm(self.head_dim, query.dtype)
):
query, key = apply_qk_norm(
q=query,
k=key,
q_norm=self.norm_q,
k_norm=self.norm_k,
head_dim=self.head_dim,
allow_inplace=True,
)
else:
query = self.norm_q(query)
key = self.norm_k(key)

if self.added_kv_proj_dim is not None:
encoder_query = encoder_query.unflatten(-1, (self.heads, -1))
encoder_key = encoder_key.unflatten(-1, (self.heads, -1))
encoder_value = encoder_value.unflatten(-1, (self.heads, -1))

encoder_query = self.norm_added_q(encoder_query)
encoder_key = self.norm_added_k(encoder_key)
if (
encoder_query.is_cuda
and (
self.norm_added_q.variance_epsilon
== self.norm_added_k.variance_epsilon
)
and can_use_fused_inplace_qknorm(self.head_dim, encoder_query.dtype)
):
encoder_query, encoder_key = apply_qk_norm(
q=encoder_query,
k=encoder_key,
q_norm=self.norm_added_q,
k_norm=self.norm_added_k,
head_dim=self.head_dim,
allow_inplace=True,
)
else:
encoder_query = self.norm_added_q(encoder_query)
encoder_key = self.norm_added_k(encoder_key)

query = torch.cat([encoder_query, query], dim=1)
key = torch.cat([encoder_key, key], dim=1)
Expand Down
Loading