From ad2b46a4732ed6db721a7a4b0db6edc73da981f4 Mon Sep 17 00:00:00 2001 From: RubiaCx <1084281732@qq.com> Date: Tue, 30 Dec 2025 05:47:18 -0800 Subject: [PATCH 1/2] fix fallback for qk_norm --- .../runtime/layers/layernorm.py | 24 ++++++--- .../runtime/models/dits/qwen_image.py | 51 ++++++++----------- 2 files changed, 37 insertions(+), 38 deletions(-) diff --git a/python/sglang/multimodal_gen/runtime/layers/layernorm.py b/python/sglang/multimodal_gen/runtime/layers/layernorm.py index 82fbb76828fe..cb66bc877fd0 100644 --- a/python/sglang/multimodal_gen/runtime/layers/layernorm.py +++ b/python/sglang/multimodal_gen/runtime/layers/layernorm.py @@ -434,17 +434,19 @@ def apply_qk_norm( ) -> Tuple[torch.Tensor, torch.Tensor]: """Apply QK normalization for query and key tensors. - Minimal multimodal_gen-only implementation: only the JIT fused inplace - QK-norm kernel path is supported (no fallback). + Prefer the fused inplace QK-norm kernel when applicable; + otherwise fall back to PyTorch norms. """ batch_size = q.size(0) q_eps = q_norm.variance_epsilon k_eps = k_norm.variance_epsilon - # Only try fused path on CUDA and when it won't introduce implicit copies. if ( q.is_cuda + and (torch.version.cuda is not None) and allow_inplace + and q.is_contiguous() + and k.is_contiguous() and (q_eps == k_eps) and can_use_fused_inplace_qknorm(head_dim) ): @@ -458,7 +460,15 @@ def apply_qk_norm( ) return q, k - raise RuntimeError( - "apply_qk_norm: fused inplace QK-norm is not applicable " - "(expected CUDA, contiguous q/k, matching eps, and supported head_dim)" - ) + # Fallback: apply the PyTorch norms. + q_out = q_norm(q) + k_out = k_norm(k) + + if allow_inplace and q.is_contiguous() and q_out.shape == q.shape: + q.copy_(q_out) + q_out = q + if allow_inplace and k.is_contiguous() and k_out.shape == k.shape: + k.copy_(k_out) + k_out = k + + return q_out, k_out diff --git a/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py b/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py index 73d478dcb217..150a64e01d64 100644 --- a/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 import functools +from math import prod from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np @@ -17,11 +18,7 @@ from sglang.multimodal_gen.configs.models.dits.qwenimage import QwenImageDitConfig from sglang.multimodal_gen.runtime.distributed import get_local_torch_device from sglang.multimodal_gen.runtime.layers.attention import USPAttention -from sglang.multimodal_gen.runtime.layers.layernorm import ( - LayerNorm, - RMSNorm, - apply_qk_norm, -) +from sglang.multimodal_gen.runtime.layers.layernorm import LayerNorm, RMSNorm from sglang.multimodal_gen.runtime.layers.linear import ReplicatedLinear from sglang.multimodal_gen.runtime.layers.rotary_embedding import ( apply_flashinfer_rope_qk_inplace, @@ -552,23 +549,14 @@ def forward( txt_value = txt_value.unflatten(-1, (self.num_heads, -1)) # Apply QK normalization - if self.qk_norm: - img_query, img_key = apply_qk_norm( - q=img_query, - k=img_key, - q_norm=self.norm_q, - k_norm=self.norm_k, - head_dim=img_query.shape[-1], - allow_inplace=True, - ) - txt_query, txt_key = apply_qk_norm( - q=txt_query, - k=txt_key, - q_norm=self.norm_added_q, - k_norm=self.norm_added_k, - head_dim=txt_query.shape[-1], - allow_inplace=True, - ) + if self.norm_q is not None: + img_query = self.norm_q(img_query) + if self.norm_k is not None: + img_key = self.norm_k(img_key) + if self.norm_added_q is not None: + txt_query = self.norm_added_q(txt_query) + if self.norm_added_k is not None: + txt_key = self.norm_added_k(txt_key) # Apply RoPE if image_rotary_emb is not None: @@ -793,12 +781,6 @@ def forward( return encoder_hidden_states, hidden_states -def to_hashable(obj): - if isinstance(obj, list): - return tuple(to_hashable(x) for x in obj) - return obj - - class QwenImageTransformer2DModel(CachableDiT, OffloadableDiTMixin): """ The Transformer model introduced in Qwen. @@ -948,9 +930,16 @@ def forward( timestep = (timestep / 1000).to(hidden_states.dtype) if self.zero_cond_t: - timestep = torch.cat([timestep, self.timestep_zero], dim=0) - device = timestep.device - modulate_index = self.build_modulate_index(to_hashable(img_shapes), device) + timestep = torch.cat([timestep, timestep * 0], dim=0) + # Use torch operations for GPU efficiency + modulate_index = torch.tensor( + [ + [0] * prod(sample[0]) + [1] * sum([prod(s) for s in sample[1:]]) + for sample in img_shapes + ], + device=timestep.device, + dtype=torch.int, + ) else: modulate_index = None From baa534f70b94be637f77387791e78f6aa3cf5813 Mon Sep 17 00:00:00 2001 From: RubiaCx <1084281732@qq.com> Date: Fri, 2 Jan 2026 20:12:41 -0800 Subject: [PATCH 2/2] make fused QK-norm optional with safe fallback --- .../runtime/layers/layernorm.py | 43 ++++++++++++----- .../runtime/models/dits/qwen_image.py | 34 +++++++++---- python/sglang/srt/models/utils.py | 48 ++++++++++++++----- 3 files changed, 93 insertions(+), 32 deletions(-) diff --git a/python/sglang/multimodal_gen/runtime/layers/layernorm.py b/python/sglang/multimodal_gen/runtime/layers/layernorm.py index cb66bc877fd0..160768a5343b 100644 --- a/python/sglang/multimodal_gen/runtime/layers/layernorm.py +++ b/python/sglang/multimodal_gen/runtime/layers/layernorm.py @@ -441,24 +441,45 @@ def apply_qk_norm( batch_size = q.size(0) q_eps = q_norm.variance_epsilon k_eps = k_norm.variance_epsilon + + def can_view_as_bnhd(x: torch.Tensor) -> bool: + """Whether `x` can be viewed as [batch, *, head_dim] without a copy.""" + if ( + x.dim() < 2 + or x.size(0) != batch_size + or x.size(-1) != head_dim + or x.stride(-1) != 1 + or x.stride(-2) != head_dim + ): + return False + try: + x.view(batch_size, -1, head_dim) + return True + except RuntimeError: + return False + if ( q.is_cuda and (torch.version.cuda is not None) and allow_inplace - and q.is_contiguous() - and k.is_contiguous() and (q_eps == k_eps) + and can_view_as_bnhd(q) + and can_view_as_bnhd(k) and can_use_fused_inplace_qknorm(head_dim) ): - fused_inplace_qknorm( - q=q.view(batch_size, -1, head_dim), - k=k.view(batch_size, -1, head_dim), - q_weight=q_norm.weight, - k_weight=k_norm.weight, - head_dim=head_dim, - eps=q_eps, - ) - return q, k + try: + fused_inplace_qknorm( + q=q.view(batch_size, -1, head_dim), + k=k.view(batch_size, -1, head_dim), + q_weight=q_norm.weight, + k_weight=k_norm.weight, + head_dim=head_dim, + eps=q_eps, + ) + return q, k + except RuntimeError as e: + if "QK-norm is not applicable" not in str(e): + raise # Fallback: apply the PyTorch norms. q_out = q_norm(q) diff --git a/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py b/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py index 150a64e01d64..1bee1cabd68e 100644 --- a/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py @@ -18,7 +18,11 @@ from sglang.multimodal_gen.configs.models.dits.qwenimage import QwenImageDitConfig from sglang.multimodal_gen.runtime.distributed import get_local_torch_device from sglang.multimodal_gen.runtime.layers.attention import USPAttention -from sglang.multimodal_gen.runtime.layers.layernorm import LayerNorm, RMSNorm +from sglang.multimodal_gen.runtime.layers.layernorm import ( + LayerNorm, + RMSNorm, + apply_qk_norm, +) from sglang.multimodal_gen.runtime.layers.linear import ReplicatedLinear from sglang.multimodal_gen.runtime.layers.rotary_embedding import ( apply_flashinfer_rope_qk_inplace, @@ -482,6 +486,9 @@ def __init__( if self.qk_norm: self.norm_q = RMSNorm(head_dim, eps=eps) if qk_norm else nn.Identity() self.norm_k = RMSNorm(head_dim, eps=eps) if qk_norm else nn.Identity() + else: + self.norm_q = None + self.norm_k = None if added_kv_proj_dim is not None: self.add_q_proj = ReplicatedLinear( @@ -549,14 +556,23 @@ def forward( txt_value = txt_value.unflatten(-1, (self.num_heads, -1)) # Apply QK normalization - if self.norm_q is not None: - img_query = self.norm_q(img_query) - if self.norm_k is not None: - img_key = self.norm_k(img_key) - if self.norm_added_q is not None: - txt_query = self.norm_added_q(txt_query) - if self.norm_added_k is not None: - txt_key = self.norm_added_k(txt_key) + if self.qk_norm: + img_query, img_key = apply_qk_norm( + q=img_query, + k=img_key, + q_norm=self.norm_q, + k_norm=self.norm_k, + head_dim=img_query.shape[-1], + allow_inplace=True, + ) + txt_query, txt_key = apply_qk_norm( + q=txt_query, + k=txt_key, + q_norm=self.norm_added_q, + k_norm=self.norm_added_k, + head_dim=txt_query.shape[-1], + allow_inplace=True, + ) # Apply RoPE if image_rotary_emb is not None: diff --git a/python/sglang/srt/models/utils.py b/python/sglang/srt/models/utils.py index 3ebd824486d0..5c01bc6b7d8c 100644 --- a/python/sglang/srt/models/utils.py +++ b/python/sglang/srt/models/utils.py @@ -228,22 +228,46 @@ def apply_qk_norm( batch_size = q.size(0) q_eps = q_norm.variance_epsilon k_eps = k_norm.variance_epsilon + + def can_view_as_bnhd(x: torch.Tensor) -> bool: + """Whether `x` can be viewed as [batch, *, head_dim] without a copy.""" + if ( + x.dim() < 2 + or x.size(0) != batch_size + or x.size(-1) != head_dim + or x.stride(-1) != 1 + or x.stride(-2) != head_dim + ): + return False + try: + x.view(batch_size, -1, head_dim) + return True + except RuntimeError: + return False + if ( - _is_cuda # TODO(dark): have not tested on ROCm or other backends - and allow_inplace # TODO(dark): this can be relaxed if needed - and (q_eps == k_eps) # TODO(dark): this can also be relaxed + _is_cuda + and (torch.version.cuda is not None) # avoid NVIDIA-only kernel on ROCm + and allow_inplace + and (q_eps == k_eps) and not envs.SGLANG_ENABLE_DETERMINISTIC_INFERENCE.get() + and can_view_as_bnhd(q) + and can_view_as_bnhd(k) and can_use_fused_inplace_qknorm(head_dim) ): - fused_inplace_qknorm( - q=q.view(batch_size, -1, head_dim), - k=k.view(batch_size, -1, head_dim), - q_weight=q_norm.weight, - k_weight=k_norm.weight, - head_dim=head_dim, - eps=q_eps, - ) - return q, k + try: + fused_inplace_qknorm( + q=q.view(batch_size, -1, head_dim), + k=k.view(batch_size, -1, head_dim), + q_weight=q_norm.weight, + k_weight=k_norm.weight, + head_dim=head_dim, + eps=q_eps, + ) + return q, k + except RuntimeError as e: + if "QK-norm is not applicable" not in str(e): + raise if alt_stream is not None and get_is_capture_mode(): current_stream = torch.cuda.current_stream()