diff --git a/python/sglang/multimodal_gen/runtime/models/dits/flux.py b/python/sglang/multimodal_gen/runtime/models/dits/flux.py
index 0ac092d9e215..f46818ca3a91 100644
--- a/python/sglang/multimodal_gen/runtime/models/dits/flux.py
+++ b/python/sglang/multimodal_gen/runtime/models/dits/flux.py
@@ -27,11 +27,12 @@
 )
 from torch.nn import LayerNorm as LayerNorm
 
+from sglang.jit_kernel.norm import can_use_fused_inplace_qknorm
 from sglang.multimodal_gen.configs.models.dits.flux import FluxConfig
 from sglang.multimodal_gen.runtime.layers.attention import USPAttention
 
 # from sglang.multimodal_gen.runtime.layers.layernorm import LayerNorm as LayerNorm
-from sglang.multimodal_gen.runtime.layers.layernorm import RMSNorm
+from sglang.multimodal_gen.runtime.layers.layernorm import RMSNorm, apply_qk_norm
 from sglang.multimodal_gen.runtime.layers.linear import ColumnParallelLinear
 from sglang.multimodal_gen.runtime.layers.mlp import MLP
 from sglang.multimodal_gen.runtime.layers.rotary_embedding import (
@@ -165,16 +166,47 @@ def forward(
         query = query.unflatten(-1, (self.heads, -1))
         key = key.unflatten(-1, (self.heads, -1))
         value = value.unflatten(-1, (self.heads, -1))
-        query = self.norm_q(query)
-        key = self.norm_k(key)
+        if (
+            query.is_cuda
+            and (self.norm_q.variance_epsilon == self.norm_k.variance_epsilon)
+            and can_use_fused_inplace_qknorm(self.head_dim, query.dtype)
+        ):
+            query, key = apply_qk_norm(
+                q=query,
+                k=key,
+                q_norm=self.norm_q,
+                k_norm=self.norm_k,
+                head_dim=self.head_dim,
+                allow_inplace=True,
+            )
+        else:
+            query = self.norm_q(query)
+            key = self.norm_k(key)
 
         if self.added_kv_proj_dim is not None:
             encoder_query = encoder_query.unflatten(-1, (self.heads, -1))
             encoder_key = encoder_key.unflatten(-1, (self.heads, -1))
             encoder_value = encoder_value.unflatten(-1, (self.heads, -1))
 
-            encoder_query = self.norm_added_q(encoder_query)
-            encoder_key = self.norm_added_k(encoder_key)
+            if (
+                encoder_query.is_cuda
+                and (
+                    self.norm_added_q.variance_epsilon
+                    == self.norm_added_k.variance_epsilon
+                )
+                and can_use_fused_inplace_qknorm(self.head_dim, encoder_query.dtype)
+            ):
+                encoder_query, encoder_key = apply_qk_norm(
+                    q=encoder_query,
+                    k=encoder_key,
+                    q_norm=self.norm_added_q,
+                    k_norm=self.norm_added_k,
+                    head_dim=self.head_dim,
+                    allow_inplace=True,
+                )
+            else:
+                encoder_query = self.norm_added_q(encoder_query)
+                encoder_key = self.norm_added_k(encoder_key)
 
         bsz, seq_len, _, _ = query.shape
         query = torch.cat([encoder_query, query], dim=1)