diff --git a/python/sglang/multimodal_gen/runtime/layers/layernorm.py b/python/sglang/multimodal_gen/runtime/layers/layernorm.py
index a9f5811e2936..4bd9368eead6 100644
--- a/python/sglang/multimodal_gen/runtime/layers/layernorm.py
+++ b/python/sglang/multimodal_gen/runtime/layers/layernorm.py
@@ -443,6 +443,15 @@ def apply_qk_norm(
     """
     batch_size = q.size(0)
 
+    q_shape = q.shape
+    k_shape = k.shape
+
+    # Ensure contiguous for view operation
+    if not q.is_contiguous():
+        q = q.contiguous()
+    if not k.is_contiguous():
+        k = k.contiguous()
+
     q_eps = q_norm.variance_epsilon
     k_eps = k_norm.variance_epsilon
     # Only try fused path on CUDA and when it won't introduce implicit copies.
@@ -460,7 +469,8 @@ def apply_qk_norm(
            head_dim=head_dim,
            eps=q_eps,
        )
-        return q, k
+        # Inplace kernel modifies q, k - return with original shape
+        return q.view(q_shape), k.view(k_shape)
 
     # Fallback for AMD/ROCm: apply RMSNorm separately to q and k
     import warnings
@@ -469,8 +479,6 @@
         "Fused QK-norm not available, using RMSNorm fallback",
         stacklevel=2,
     )
-    q_shape = q.shape
-    k_shape = k.shape
     q_out = q_norm(q.view(-1, head_dim)).view(q_shape)
     k_out = k_norm(k.view(-1, head_dim)).view(k_shape)
     return q_out, k_out
diff --git a/python/sglang/multimodal_gen/runtime/models/dits/flux_2.py b/python/sglang/multimodal_gen/runtime/models/dits/flux_2.py
index a157b88d301b..585c30cb5b84 100644
--- a/python/sglang/multimodal_gen/runtime/models/dits/flux_2.py
+++ b/python/sglang/multimodal_gen/runtime/models/dits/flux_2.py
@@ -22,7 +22,7 @@
 from sglang.multimodal_gen.configs.models.dits.flux import FluxConfig
 from sglang.multimodal_gen.runtime.layers.attention import USPAttention
-from sglang.multimodal_gen.runtime.layers.layernorm import RMSNorm
+from sglang.multimodal_gen.runtime.layers.layernorm import RMSNorm, apply_qk_norm
 from sglang.multimodal_gen.runtime.layers.linear import ColumnParallelLinear
 from sglang.multimodal_gen.runtime.layers.rotary_embedding import (
     NDRotaryEmbedding,
@@ -196,8 +196,7 @@ def forward(
         key = key.unflatten(-1, (self.heads, -1))
         value = value.unflatten(-1, (self.heads, -1))
 
-        query = self.norm_q(query)
-        key = self.norm_k(key)
+        query, key = apply_qk_norm(query, key, self.norm_q, self.norm_k, self.head_dim)
 
         if self.added_kv_proj_dim is not None:
             encoder_query = encoder_query.unflatten(-1, (self.heads, -1))
@@ -340,8 +339,7 @@ def forward(
         key = key.unflatten(-1, (self.heads, -1))
         value = value.unflatten(-1, (self.heads, -1))
 
-        query = self.norm_q(query)
-        key = self.norm_k(key)
+        query, key = apply_qk_norm(query, key, self.norm_q, self.norm_k, self.head_dim)
 
         if freqs_cis is not None:
             cos, sin = freqs_cis
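
For context, here is a minimal sketch of the shape bookkeeping this patch adds, with a reference RMSNorm standing in for the fused in-place CUDA kernel. The names `apply_qk_norm_sketch` and `_rms_norm`, and the `q_weight`/`k_weight` parameters, are illustrative only, not the repo's API:

```python
import torch


def _rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Reference RMSNorm over the last dimension (stand-in for the fused kernel).
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight


def apply_qk_norm_sketch(q, k, q_weight, k_weight, head_dim, eps=1e-6):
    # Capture the callers' shapes, e.g. (batch, seq, heads, head_dim),
    # before any reshaping happens.
    q_shape, k_shape = q.shape, k.shape

    # .view() requires contiguous memory; tensors produced by unflatten()
    # or transposes may not be contiguous.
    if not q.is_contiguous():
        q = q.contiguous()
    if not k.is_contiguous():
        k = k.contiguous()

    # The fused kernel consumes q/k flattened to (-1, head_dim);
    # emulate that layout with the reference norm.
    q_out = _rms_norm(q.view(-1, head_dim), q_weight, eps)
    k_out = _rms_norm(k.view(-1, head_dim), k_weight, eps)

    # Hand the original shapes back, as the patch does after the kernel returns.
    return q_out.view(q_shape), k_out.view(k_shape)


# Example: 4D q/k as produced by unflatten(-1, (heads, -1)) in flux_2.py.
q = torch.randn(2, 16, 8, 64)
k = torch.randn(2, 16, 8, 64)
w = torch.ones(64)
q_n, k_n = apply_qk_norm_sketch(q, k, w, w, head_dim=64)
assert q_n.shape == q.shape and k_n.shape == k.shape
```

The point mirrored from the patch: `q_shape`/`k_shape` are captured before any `.contiguous()` or `.view(-1, head_dim)` reshaping, so both the fused and fallback paths return tensors in exactly the shape the caller passed in.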