diff --git a/python/sglang/multimodal_gen/runtime/layers/layernorm.py b/python/sglang/multimodal_gen/runtime/layers/layernorm.py index 82fbb76828fe..160768a5343b 100644 --- a/python/sglang/multimodal_gen/runtime/layers/layernorm.py +++ b/python/sglang/multimodal_gen/runtime/layers/layernorm.py @@ -434,31 +434,62 @@ def apply_qk_norm( ) -> Tuple[torch.Tensor, torch.Tensor]: """Apply QK normalization for query and key tensors. - Minimal multimodal_gen-only implementation: only the JIT fused inplace - QK-norm kernel path is supported (no fallback). + Prefer the fused inplace QK-norm kernel when applicable; + otherwise fall back to PyTorch norms. """ batch_size = q.size(0) q_eps = q_norm.variance_epsilon k_eps = k_norm.variance_epsilon - # Only try fused path on CUDA and when it won't introduce implicit copies. + + def can_view_as_bnhd(x: torch.Tensor) -> bool: + """Whether `x` can be viewed as [batch, *, head_dim] without a copy.""" + if ( + x.dim() < 2 + or x.size(0) != batch_size + or x.size(-1) != head_dim + or x.stride(-1) != 1 + or x.stride(-2) != head_dim + ): + return False + try: + x.view(batch_size, -1, head_dim) + return True + except RuntimeError: + return False + if ( q.is_cuda + and (torch.version.cuda is not None) and allow_inplace and (q_eps == k_eps) + and can_view_as_bnhd(q) + and can_view_as_bnhd(k) and can_use_fused_inplace_qknorm(head_dim) ): - fused_inplace_qknorm( - q=q.view(batch_size, -1, head_dim), - k=k.view(batch_size, -1, head_dim), - q_weight=q_norm.weight, - k_weight=k_norm.weight, - head_dim=head_dim, - eps=q_eps, - ) - return q, k - - raise RuntimeError( - "apply_qk_norm: fused inplace QK-norm is not applicable " - "(expected CUDA, contiguous q/k, matching eps, and supported head_dim)" - ) + try: + fused_inplace_qknorm( + q=q.view(batch_size, -1, head_dim), + k=k.view(batch_size, -1, head_dim), + q_weight=q_norm.weight, + k_weight=k_norm.weight, + head_dim=head_dim, + eps=q_eps, + ) + return q, k + except RuntimeError as e: + if "QK-norm is not applicable" not in str(e): + raise + + # Fallback: apply the PyTorch norms. + q_out = q_norm(q) + k_out = k_norm(k) + + if allow_inplace and q.is_contiguous() and q_out.shape == q.shape: + q.copy_(q_out) + q_out = q + if allow_inplace and k.is_contiguous() and k_out.shape == k.shape: + k.copy_(k_out) + k_out = k + + return q_out, k_out diff --git a/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py b/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py index 73d478dcb217..1bee1cabd68e 100644 --- a/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 import functools +from math import prod from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np @@ -485,6 +486,9 @@ def __init__( if self.qk_norm: self.norm_q = RMSNorm(head_dim, eps=eps) if qk_norm else nn.Identity() self.norm_k = RMSNorm(head_dim, eps=eps) if qk_norm else nn.Identity() + else: + self.norm_q = None + self.norm_k = None if added_kv_proj_dim is not None: self.add_q_proj = ReplicatedLinear( @@ -793,12 +797,6 @@ def forward( return encoder_hidden_states, hidden_states -def to_hashable(obj): - if isinstance(obj, list): - return tuple(to_hashable(x) for x in obj) - return obj - - class QwenImageTransformer2DModel(CachableDiT, OffloadableDiTMixin): """ The Transformer model introduced in Qwen. @@ -948,9 +946,16 @@ def forward( timestep = (timestep / 1000).to(hidden_states.dtype) if self.zero_cond_t: - timestep = torch.cat([timestep, self.timestep_zero], dim=0) - device = timestep.device - modulate_index = self.build_modulate_index(to_hashable(img_shapes), device) + timestep = torch.cat([timestep, timestep * 0], dim=0) + # Use torch operations for GPU efficiency + modulate_index = torch.tensor( + [ + [0] * prod(sample[0]) + [1] * sum([prod(s) for s in sample[1:]]) + for sample in img_shapes + ], + device=timestep.device, + dtype=torch.int, + ) else: modulate_index = None diff --git a/python/sglang/srt/models/utils.py b/python/sglang/srt/models/utils.py index 3ebd824486d0..5c01bc6b7d8c 100644 --- a/python/sglang/srt/models/utils.py +++ b/python/sglang/srt/models/utils.py @@ -228,22 +228,46 @@ def apply_qk_norm( batch_size = q.size(0) q_eps = q_norm.variance_epsilon k_eps = k_norm.variance_epsilon + + def can_view_as_bnhd(x: torch.Tensor) -> bool: + """Whether `x` can be viewed as [batch, *, head_dim] without a copy.""" + if ( + x.dim() < 2 + or x.size(0) != batch_size + or x.size(-1) != head_dim + or x.stride(-1) != 1 + or x.stride(-2) != head_dim + ): + return False + try: + x.view(batch_size, -1, head_dim) + return True + except RuntimeError: + return False + if ( - _is_cuda # TODO(dark): have not tested on ROCm or other backends - and allow_inplace # TODO(dark): this can be relaxed if needed - and (q_eps == k_eps) # TODO(dark): this can also be relaxed + _is_cuda + and (torch.version.cuda is not None) # avoid NVIDIA-only kernel on ROCm + and allow_inplace + and (q_eps == k_eps) and not envs.SGLANG_ENABLE_DETERMINISTIC_INFERENCE.get() + and can_view_as_bnhd(q) + and can_view_as_bnhd(k) and can_use_fused_inplace_qknorm(head_dim) ): - fused_inplace_qknorm( - q=q.view(batch_size, -1, head_dim), - k=k.view(batch_size, -1, head_dim), - q_weight=q_norm.weight, - k_weight=k_norm.weight, - head_dim=head_dim, - eps=q_eps, - ) - return q, k + try: + fused_inplace_qknorm( + q=q.view(batch_size, -1, head_dim), + k=k.view(batch_size, -1, head_dim), + q_weight=q_norm.weight, + k_weight=k_norm.weight, + head_dim=head_dim, + eps=q_eps, + ) + return q, k + except RuntimeError as e: + if "QK-norm is not applicable" not in str(e): + raise if alt_stream is not None and get_is_capture_mode(): current_stream = torch.cuda.current_stream()