Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions python/sglang/multimodal_gen/runtime/layers/layernorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,15 @@ def apply_qk_norm(
"""

batch_size = q.size(0)
q_shape = q.shape
k_shape = k.shape

# Ensure contiguous for view operation
if not q.is_contiguous():
q = q.contiguous()
if not k.is_contiguous():
k = k.contiguous()

q_eps = q_norm.variance_epsilon
k_eps = k_norm.variance_epsilon
# Only try fused path on CUDA and when it won't introduce implicit copies.
Expand All @@ -460,7 +469,8 @@ def apply_qk_norm(
head_dim=head_dim,
eps=q_eps,
)
return q, k
# Inplace kernel modifies q, k - return with original shape
return q.view(q_shape), k.view(k_shape)

# Fallback for AMD/ROCm: apply RMSNorm separately to q and k
import warnings
Expand All @@ -469,8 +479,6 @@ def apply_qk_norm(
"Fused QK-norm not available, using RMSNorm fallback",
stacklevel=2,
)
q_shape = q.shape
k_shape = k.shape
q_out = q_norm(q.view(-1, head_dim)).view(q_shape)
k_out = k_norm(k.view(-1, head_dim)).view(k_shape)
return q_out, k_out
Expand Down
8 changes: 3 additions & 5 deletions python/sglang/multimodal_gen/runtime/models/dits/flux_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from sglang.multimodal_gen.configs.models.dits.flux import FluxConfig
from sglang.multimodal_gen.runtime.layers.attention import USPAttention
from sglang.multimodal_gen.runtime.layers.layernorm import RMSNorm
from sglang.multimodal_gen.runtime.layers.layernorm import RMSNorm, apply_qk_norm
from sglang.multimodal_gen.runtime.layers.linear import ColumnParallelLinear
from sglang.multimodal_gen.runtime.layers.rotary_embedding import (
NDRotaryEmbedding,
Expand Down Expand Up @@ -196,8 +196,7 @@ def forward(
key = key.unflatten(-1, (self.heads, -1))
value = value.unflatten(-1, (self.heads, -1))

query = self.norm_q(query)
key = self.norm_k(key)
query, key = apply_qk_norm(query, key, self.norm_q, self.norm_k, self.head_dim)

if self.added_kv_proj_dim is not None:
encoder_query = encoder_query.unflatten(-1, (self.heads, -1))
Expand Down Expand Up @@ -340,8 +339,7 @@ def forward(
key = key.unflatten(-1, (self.heads, -1))
value = value.unflatten(-1, (self.heads, -1))

query = self.norm_q(query)
key = self.norm_k(key)
query, key = apply_qk_norm(query, key, self.norm_q, self.norm_k, self.head_dim)

if freqs_cis is not None:
cos, sin = freqs_cis
Expand Down
Loading