diff --git a/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py b/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py
index ea363021a62f..46ae0ca8dee2 100644
--- a/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py
+++ b/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py
@@ -47,12 +47,16 @@
     apply_flashinfer_rope_qk_inplace,
 )
 from sglang.multimodal_gen.runtime.models.dits.base import CachableDiT
-from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum
+from sglang.multimodal_gen.runtime.platforms import (
+    AttentionBackendEnum,
+    current_platform,
+)
 from sglang.multimodal_gen.runtime.utils.layerwise_offload import OffloadableDiTMixin
 from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
 
 logger = init_logger(__name__)  # pylint: disable=invalid-name
 
+
 try:
     from nunchaku.models.attention import NunchakuFeedForward  # type: ignore[import]
 except Exception:
@@ -803,30 +807,52 @@ def _modulate(
         shift, scale, gate = mod_params.chunk(3, dim=-1)
 
         if index is not None:
-            actual_batch = x.shape[0]
-            shift0, shift1 = (
-                shift[:actual_batch],
-                shift[actual_batch : 2 * actual_batch],
-            )
-            scale0, scale1 = (
-                scale[:actual_batch],
-                scale[actual_batch : 2 * actual_batch],
-            )
-            gate0, gate1 = gate[:actual_batch], gate[actual_batch : 2 * actual_batch]
-            if not x.is_contiguous():
-                x = x.contiguous()
-            if not index.is_contiguous():
-                index = index.contiguous()
-            if is_scale_residual:
-                if not residual_x.is_contiguous():
-                    residual_x = residual_x.contiguous()
-                if not gate_x.is_contiguous():
-                    gate_x = gate_x.contiguous()
-                x, residual_out, gate_result = (
-                    fuse_residual_layernorm_scale_shift_gate_select01_kernel(
+            # ROCm currently fails to compile the select01 Triton kernel, so
+            # keep using the torch.where fallback there.
+            if x.is_cuda and not current_platform.is_hip():
+                actual_batch = x.shape[0]
+                shift0, shift1 = (
+                    shift[:actual_batch],
+                    shift[actual_batch : 2 * actual_batch],
+                )
+                scale0, scale1 = (
+                    scale[:actual_batch],
+                    scale[actual_batch : 2 * actual_batch],
+                )
+                gate0, gate1 = (
+                    gate[:actual_batch],
+                    gate[actual_batch : 2 * actual_batch],
+                )
+                if not x.is_contiguous():
+                    x = x.contiguous()
+                if not index.is_contiguous():
+                    index = index.contiguous()
+                if is_scale_residual:
+                    if not residual_x.is_contiguous():
+                        residual_x = residual_x.contiguous()
+                    if not gate_x.is_contiguous():
+                        gate_x = gate_x.contiguous()
+                    x, residual_out, gate_result = (
+                        fuse_residual_layernorm_scale_shift_gate_select01_kernel(
+                            x,
+                            residual=residual_x,
+                            residual_gate=gate_x,
+                            weight=getattr(norm_module.norm, "weight", None),
+                            bias=getattr(norm_module.norm, "bias", None),
+                            scale0=scale0.contiguous(),
+                            shift0=shift0.contiguous(),
+                            gate0=gate0.contiguous(),
+                            scale1=scale1.contiguous(),
+                            shift1=shift1.contiguous(),
+                            gate1=gate1.contiguous(),
+                            index=index,
+                            eps=norm_module.eps,
+                        )
+                    )
+                    return x, residual_out, gate_result
+                else:
+                    x, gate_result = fuse_layernorm_scale_shift_gate_select01_kernel(
                         x,
-                        residual=residual_x,
-                        residual_gate=gate_x,
                         weight=getattr(norm_module.norm, "weight", None),
                         bias=getattr(norm_module.norm, "bias", None),
                         scale0=scale0.contiguous(),
@@ -838,39 +864,45 @@ def _modulate(
                         index=index,
                         eps=norm_module.eps,
                     )
-                )
-                return x, residual_out, gate_result
+                    return x, gate_result
             else:
-                x, gate_result = fuse_layernorm_scale_shift_gate_select01_kernel(
-                    x,
-                    weight=getattr(norm_module.norm, "weight", None),
-                    bias=getattr(norm_module.norm, "bias", None),
-                    scale0=scale0.contiguous(),
-                    shift0=shift0.contiguous(),
-                    gate0=gate0.contiguous(),
-                    scale1=scale1.contiguous(),
-                    shift1=shift1.contiguous(),
-                    gate1=gate1.contiguous(),
-                    index=index,
-                    eps=norm_module.eps,
+                actual_batch = x.shape[0]
+                shift0, shift1 = (
+                    shift[:actual_batch],
+                    shift[actual_batch : 2 * actual_batch],
                 )
-                return x, gate_result
+                scale0, scale1 = (
+                    scale[:actual_batch],
+                    scale[actual_batch : 2 * actual_batch],
+                )
+                gate0, gate1 = (
+                    gate[:actual_batch],
+                    gate[actual_batch : 2 * actual_batch],
+                )
+                index = index.to(dtype=torch.bool).unsqueeze(-1)
+                shift_result = torch.where(
+                    index, shift1.unsqueeze(1), shift0.unsqueeze(1)
+                )
+                scale_result = torch.where(
+                    index, scale1.unsqueeze(1), scale0.unsqueeze(1)
+                )
+                gate_result = torch.where(index, gate1.unsqueeze(1), gate0.unsqueeze(1))
         else:
             shift_result = shift.unsqueeze(1)
             scale_result = scale.unsqueeze(1)
            gate_result = gate.unsqueeze(1)
-            if is_scale_residual:
-                modulated, residual_out = norm_module(
-                    residual=residual_x,
-                    x=x,
-                    gate=gate_x,
-                    shift=shift_result,
-                    scale=scale_result,
-                )
-                return modulated, residual_out, gate_result
-            else:
-                modulated = norm_module(x=x, shift=shift_result, scale=scale_result)
-                return modulated, gate_result
+        if is_scale_residual:
+            modulated, residual_out = norm_module(
+                residual=residual_x,
+                x=x,
+                gate=gate_x,
+                shift=shift_result,
+                scale=scale_result,
+            )
+            return modulated, residual_out, gate_result
+        else:
+            modulated = norm_module(x=x, shift=shift_result, scale=scale_result)
+            return modulated, gate_result
 
     def forward(
         self,
@@ -1104,8 +1136,8 @@ def build_modulate_index(self, img_shapes: tuple[int, int, int], device):
             first_size = sample[0][0] * sample[0][1] * sample[0][2]
             total_size = sum(s[0] * s[1] * s[2] for s in sample)
             if sp_world_size > 1:
-                first_local_size = _local_seq_len(first_size)
-                tail_local_size = _local_seq_len(total_size - first_size)
+                first_local_size = _local_seq_len(first_size, sp_world_size)
+                tail_local_size = _local_seq_len(total_size - first_size, sp_world_size)
                 idx = torch.cat(
                     [
                         torch.zeros(first_local_size, device=device, dtype=torch.int),
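For readers who do not want to trace the interleaved hunks above, the sketch below shows what the torch.where fallback kept for ROCm computes: the doubled modulation batch is split into a "stream 0" and a "stream 1" half, and a per-token 0/1 index picks which half modulates each token (the CUDA path performs the same selection inside the fused select01 Triton kernels). The helper name, the assumed tensor shapes, and the final norm(x) * (1 + scale) + shift AdaLN step are illustrative assumptions, not the repository's API.

import torch
import torch.nn.functional as F


def select01_modulate_fallback(x, shift, scale, gate, index, eps=1e-6):
    # x:                 (batch, seq, dim) hidden states
    # shift/scale/gate:  (2 * batch, dim); rows [:batch] belong to stream 0,
    #                    rows [batch : 2 * batch] to stream 1
    # index:             (batch, seq) or (seq,) flag, 1 where a token should
    #                    use the stream-1 parameters
    batch = x.shape[0]
    shift0, shift1 = shift[:batch], shift[batch : 2 * batch]
    scale0, scale1 = scale[:batch], scale[batch : 2 * batch]
    gate0, gate1 = gate[:batch], gate[batch : 2 * batch]

    # Broadcast the per-token flag against the per-sample parameters.
    mask = index.to(dtype=torch.bool).unsqueeze(-1)  # (..., 1)
    shift_sel = torch.where(mask, shift1.unsqueeze(1), shift0.unsqueeze(1))
    scale_sel = torch.where(mask, scale1.unsqueeze(1), scale0.unsqueeze(1))
    gate_sel = torch.where(mask, gate1.unsqueeze(1), gate0.unsqueeze(1))

    # Plain LayerNorm followed by scale/shift, i.e. roughly what the fused
    # Triton kernel does in one pass on CUDA (assumed AdaLN convention).
    normed = F.layer_norm(x, (x.shape[-1],), eps=eps)
    return normed * (1 + scale_sel) + shift_sel, gate_sel

As in the patched _modulate, the selected gate is returned unused here so the caller can apply it when the block's output is folded back into the residual stream.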