Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion python/sglang/multimodal_gen/runtime/layers/layernorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
fuse_scale_shift_kernel,
norm_infer,
rms_norm_fn,
triton_one_pass_rms_norm,
)
from sglang.multimodal_gen.runtime.utils.common import get_bool_env_var

Expand Down Expand Up @@ -76,7 +77,12 @@ def forward_cuda(
fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
return x.view(shape), residual.view(residual_shape)
else:
out = rmsnorm(x, self.weight.data, self.variance_epsilon)
if x.shape[-1] <= 128:
Comment thread
BBuf marked this conversation as resolved.
out = triton_one_pass_rms_norm(
x, self.weight.data, self.variance_epsilon
)
else:
out = rmsnorm(x, self.weight.data, self.variance_epsilon)
out = out.view(shape)
return out

Expand Down
55 changes: 55 additions & 0 deletions python/sglang/multimodal_gen/runtime/layers/triton_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1106,3 +1106,58 @@ def rms_norm_fn(
out,
residual_out,
)


# Adapted from https://github.com/ModelTC/LightX2V/blob/main/lightx2v/common/ops/norm/triton_ops.py#L905-L956
@triton.jit
def _rms_norm_tiled_onepass(
y_ptr,
x_ptr,
w_ptr,
SEQ: tl.constexpr,
DIM: tl.constexpr,
EPS: tl.constexpr,
BLOCK_SIZE_SEQ: tl.constexpr,
BLOCK_SIZE_DIM: tl.constexpr,
):
seq_blk_id = tl.program_id(0)
seq_id = seq_blk_id * BLOCK_SIZE_SEQ

seq_offset = seq_id + tl.arange(0, BLOCK_SIZE_SEQ)[:, None]
s_mask = seq_offset < SEQ
d_offset = tl.arange(0, BLOCK_SIZE_DIM)[None, :]
d_mask = d_offset < DIM
y_blk = y_ptr + seq_offset * DIM + d_offset
x_blk = x_ptr + seq_offset * DIM + d_offset
mask = s_mask & d_mask

x = tl.load(x_blk, mask=mask, other=0.0).to(tl.float32)
mean_square = tl.sum(x * x, axis=1, keep_dims=True) / DIM
rstd = tl.math.rsqrt(mean_square + EPS)
w = tl.load(w_ptr + d_offset, mask=mask)
Comment thread
BBuf marked this conversation as resolved.
Outdated
tl.store(y_blk, x * rstd * w, mask=mask)


def triton_one_pass_rms_norm(x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6):
shape = x.shape
x = x.contiguous()
y = torch.empty_like(x)
x_view = x.reshape(-1, shape[-1])
y_view = y.reshape(-1, shape[-1])
S, D = x_view.shape

BLOCK_SIZE_SEQ = min(16, triton.next_power_of_2(max(1, S // 512)))
grid = (triton.cdiv(S, BLOCK_SIZE_SEQ),)

with torch.cuda.device(x.device):
torch.library.wrap_triton(_rms_norm_tiled_onepass)[grid](
y_view,
x_view,
w,
S,
D,
eps,
BLOCK_SIZE_DIM=triton.next_power_of_2(D),
BLOCK_SIZE_SEQ=BLOCK_SIZE_SEQ,
)
return y
42 changes: 37 additions & 5 deletions python/sglang/multimodal_gen/runtime/models/dits/flux_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from diffusers.models.normalization import AdaLayerNormContinuous

from sglang.jit_kernel.norm import can_use_fused_inplace_qknorm
from sglang.multimodal_gen.configs.models.dits.flux import FluxConfig
from sglang.multimodal_gen.runtime.layers.attention import USPAttention
from sglang.multimodal_gen.runtime.layers.layernorm import RMSNorm
from sglang.multimodal_gen.runtime.layers.layernorm import RMSNorm, apply_qk_norm
from sglang.multimodal_gen.runtime.layers.linear import ColumnParallelLinear
from sglang.multimodal_gen.runtime.layers.rotary_embedding import (
NDRotaryEmbedding,
Expand Down Expand Up @@ -196,16 +197,47 @@ def forward(
key = key.unflatten(-1, (self.heads, -1))
value = value.unflatten(-1, (self.heads, -1))

query = self.norm_q(query)
key = self.norm_k(key)
if (
query.is_cuda
and (self.norm_q.variance_epsilon == self.norm_k.variance_epsilon)
and can_use_fused_inplace_qknorm(self.head_dim, query.dtype)
):
query, key = apply_qk_norm(
q=query,
k=key,
q_norm=self.norm_q,
k_norm=self.norm_k,
head_dim=self.head_dim,
allow_inplace=True,
)
else:
query = self.norm_q(query)
key = self.norm_k(key)

if self.added_kv_proj_dim is not None:
encoder_query = encoder_query.unflatten(-1, (self.heads, -1))
encoder_key = encoder_key.unflatten(-1, (self.heads, -1))
encoder_value = encoder_value.unflatten(-1, (self.heads, -1))

encoder_query = self.norm_added_q(encoder_query)
encoder_key = self.norm_added_k(encoder_key)
if (
encoder_query.is_cuda
and (
self.norm_added_q.variance_epsilon
== self.norm_added_k.variance_epsilon
)
and can_use_fused_inplace_qknorm(self.head_dim, encoder_query.dtype)
):
encoder_query, encoder_key = apply_qk_norm(
q=encoder_query,
k=encoder_key,
q_norm=self.norm_added_q,
k_norm=self.norm_added_k,
head_dim=self.head_dim,
allow_inplace=True,
)
else:
encoder_query = self.norm_added_q(encoder_query)
encoder_key = self.norm_added_k(encoder_key)
Comment thread
BBuf marked this conversation as resolved.

query = torch.cat([encoder_query, query], dim=1)
key = torch.cat([encoder_key, key], dim=1)
Expand Down
Loading