diff --git a/python/sglang/multimodal_gen/runtime/layers/layernorm.py b/python/sglang/multimodal_gen/runtime/layers/layernorm.py
index a9f5811e2936..27f040054b39 100644
--- a/python/sglang/multimodal_gen/runtime/layers/layernorm.py
+++ b/python/sglang/multimodal_gen/runtime/layers/layernorm.py
@@ -21,6 +21,7 @@
     fuse_scale_shift_kernel,
     norm_infer,
     rms_norm_fn,
+    triton_one_pass_rms_norm,
 )
 from sglang.multimodal_gen.runtime.utils.common import get_bool_env_var
 
@@ -76,7 +77,12 @@ def forward_cuda(
             fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
             return x.view(shape), residual.view(residual_shape)
         else:
-            out = rmsnorm(x, self.weight.data, self.variance_epsilon)
+            if x.shape[-1] <= 128:
+                out = triton_one_pass_rms_norm(
+                    x, self.weight.data, self.variance_epsilon
+                )
+            else:
+                out = rmsnorm(x, self.weight.data, self.variance_epsilon)
             out = out.view(shape)
             return out
 
diff --git a/python/sglang/multimodal_gen/runtime/layers/triton_ops.py b/python/sglang/multimodal_gen/runtime/layers/triton_ops.py
index 761b13763856..aa0fc17b9d5a 100644
--- a/python/sglang/multimodal_gen/runtime/layers/triton_ops.py
+++ b/python/sglang/multimodal_gen/runtime/layers/triton_ops.py
@@ -1106,3 +1106,58 @@ def rms_norm_fn(
         out,
         residual_out,
     )
+
+
+# Adapted from https://github.com/ModelTC/LightX2V/blob/main/lightx2v/common/ops/norm/triton_ops.py#L905-L956
+@triton.jit
+def _rms_norm_tiled_onepass(
+    y_ptr,
+    x_ptr,
+    w_ptr,
+    SEQ: tl.constexpr,
+    DIM: tl.constexpr,
+    EPS: tl.constexpr,
+    BLOCK_SIZE_SEQ: tl.constexpr,
+    BLOCK_SIZE_DIM: tl.constexpr,
+):
+    seq_blk_id = tl.program_id(0)
+    seq_id = seq_blk_id * BLOCK_SIZE_SEQ
+
+    seq_offset = seq_id + tl.arange(0, BLOCK_SIZE_SEQ)[:, None]
+    s_mask = seq_offset < SEQ
+    d_offset = tl.arange(0, BLOCK_SIZE_DIM)[None, :]
+    d_mask = d_offset < DIM
+    y_blk = y_ptr + seq_offset * DIM + d_offset
+    x_blk = x_ptr + seq_offset * DIM + d_offset
+    mask = s_mask & d_mask
+
+    x = tl.load(x_blk, mask=mask, other=0.0).to(tl.float32)
+    mean_square = tl.sum(x * x, axis=1, keep_dims=True) / DIM
+    rstd = tl.math.rsqrt(mean_square + EPS)
+    w = tl.load(w_ptr + d_offset, mask=d_mask)
+    tl.store(y_blk, x * rstd * w, mask=mask)
+
+
+def triton_one_pass_rms_norm(x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6):
+    shape = x.shape
+    x = x.contiguous()
+    y = torch.empty_like(x)
+    x_view = x.reshape(-1, shape[-1])
+    y_view = y.reshape(-1, shape[-1])
+    S, D = x_view.shape
+
+    BLOCK_SIZE_SEQ = min(16, triton.next_power_of_2(max(1, S // 512)))
+    grid = (triton.cdiv(S, BLOCK_SIZE_SEQ),)
+
+    with torch.cuda.device(x.device):
+        torch.library.wrap_triton(_rms_norm_tiled_onepass)[grid](
+            y_view,
+            x_view,
+            w,
+            S,
+            D,
+            eps,
+            BLOCK_SIZE_DIM=triton.next_power_of_2(D),
+            BLOCK_SIZE_SEQ=BLOCK_SIZE_SEQ,
+        )
+    return y
diff --git a/python/sglang/multimodal_gen/runtime/models/dits/flux_2.py b/python/sglang/multimodal_gen/runtime/models/dits/flux_2.py
index a157b88d301b..c1b88f15461a 100644
--- a/python/sglang/multimodal_gen/runtime/models/dits/flux_2.py
+++ b/python/sglang/multimodal_gen/runtime/models/dits/flux_2.py
@@ -20,9 +20,10 @@
 from diffusers.models.embeddings import TimestepEmbedding, Timesteps
 from diffusers.models.normalization import AdaLayerNormContinuous
 
+from sglang.jit_kernel.norm import can_use_fused_inplace_qknorm
 from sglang.multimodal_gen.configs.models.dits.flux import FluxConfig
 from sglang.multimodal_gen.runtime.layers.attention import USPAttention
-from sglang.multimodal_gen.runtime.layers.layernorm import RMSNorm
+from sglang.multimodal_gen.runtime.layers.layernorm import RMSNorm, apply_qk_norm
 from sglang.multimodal_gen.runtime.layers.linear import ColumnParallelLinear
 from sglang.multimodal_gen.runtime.layers.rotary_embedding import (
     NDRotaryEmbedding,
@@ -196,16 +197,47 @@ def forward(
         key = key.unflatten(-1, (self.heads, -1))
         value = value.unflatten(-1, (self.heads, -1))
 
-        query = self.norm_q(query)
-        key = self.norm_k(key)
+        if (
+            query.is_cuda
+            and (self.norm_q.variance_epsilon == self.norm_k.variance_epsilon)
+            and can_use_fused_inplace_qknorm(self.head_dim, query.dtype)
+        ):
+            query, key = apply_qk_norm(
+                q=query,
+                k=key,
+                q_norm=self.norm_q,
+                k_norm=self.norm_k,
+                head_dim=self.head_dim,
+                allow_inplace=True,
+            )
+        else:
+            query = self.norm_q(query)
+            key = self.norm_k(key)
 
         if self.added_kv_proj_dim is not None:
             encoder_query = encoder_query.unflatten(-1, (self.heads, -1))
             encoder_key = encoder_key.unflatten(-1, (self.heads, -1))
             encoder_value = encoder_value.unflatten(-1, (self.heads, -1))
 
-            encoder_query = self.norm_added_q(encoder_query)
-            encoder_key = self.norm_added_k(encoder_key)
+            if (
+                encoder_query.is_cuda
+                and (
+                    self.norm_added_q.variance_epsilon
+                    == self.norm_added_k.variance_epsilon
+                )
+                and can_use_fused_inplace_qknorm(self.head_dim, encoder_query.dtype)
+            ):
+                encoder_query, encoder_key = apply_qk_norm(
+                    q=encoder_query,
+                    k=encoder_key,
+                    q_norm=self.norm_added_q,
+                    k_norm=self.norm_added_k,
+                    head_dim=self.head_dim,
+                    allow_inplace=True,
+                )
+            else:
+                encoder_query = self.norm_added_q(encoder_query)
+                encoder_key = self.norm_added_k(encoder_key)
 
         query = torch.cat([encoder_query, query], dim=1)
         key = torch.cat([encoder_key, key], dim=1)