Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 85 additions & 53 deletions python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,16 @@
apply_flashinfer_rope_qk_inplace,
)
from sglang.multimodal_gen.runtime.models.dits.base import CachableDiT
from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum
from sglang.multimodal_gen.runtime.platforms import (
AttentionBackendEnum,
current_platform,
)
from sglang.multimodal_gen.runtime.utils.layerwise_offload import OffloadableDiTMixin
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger

logger = init_logger(__name__) # pylint: disable=invalid-name


try:
from nunchaku.models.attention import NunchakuFeedForward # type: ignore[import]
except Exception:
Expand Down Expand Up @@ -803,30 +807,52 @@ def _modulate(

shift, scale, gate = mod_params.chunk(3, dim=-1)
if index is not None:
actual_batch = x.shape[0]
shift0, shift1 = (
shift[:actual_batch],
shift[actual_batch : 2 * actual_batch],
)
scale0, scale1 = (
scale[:actual_batch],
scale[actual_batch : 2 * actual_batch],
)
gate0, gate1 = gate[:actual_batch], gate[actual_batch : 2 * actual_batch]
if not x.is_contiguous():
x = x.contiguous()
if not index.is_contiguous():
index = index.contiguous()
if is_scale_residual:
if not residual_x.is_contiguous():
residual_x = residual_x.contiguous()
if not gate_x.is_contiguous():
gate_x = gate_x.contiguous()
x, residual_out, gate_result = (
fuse_residual_layernorm_scale_shift_gate_select01_kernel(
# ROCm currently fails to compile the select01 Triton kernel, so
# keep using the torch.where fallback there.
if x.is_cuda and not current_platform.is_hip():
actual_batch = x.shape[0]
shift0, shift1 = (
shift[:actual_batch],
shift[actual_batch : 2 * actual_batch],
)
scale0, scale1 = (
scale[:actual_batch],
scale[actual_batch : 2 * actual_batch],
)
gate0, gate1 = (
gate[:actual_batch],
gate[actual_batch : 2 * actual_batch],
)
if not x.is_contiguous():
x = x.contiguous()
if not index.is_contiguous():
index = index.contiguous()
if is_scale_residual:
if not residual_x.is_contiguous():
residual_x = residual_x.contiguous()
if not gate_x.is_contiguous():
gate_x = gate_x.contiguous()
x, residual_out, gate_result = (
fuse_residual_layernorm_scale_shift_gate_select01_kernel(
x,
residual=residual_x,
residual_gate=gate_x,
weight=getattr(norm_module.norm, "weight", None),
bias=getattr(norm_module.norm, "bias", None),
scale0=scale0.contiguous(),
shift0=shift0.contiguous(),
gate0=gate0.contiguous(),
scale1=scale1.contiguous(),
shift1=shift1.contiguous(),
gate1=gate1.contiguous(),
index=index,
eps=norm_module.eps,
)
)
return x, residual_out, gate_result
else:
x, gate_result = fuse_layernorm_scale_shift_gate_select01_kernel(
x,
residual=residual_x,
residual_gate=gate_x,
weight=getattr(norm_module.norm, "weight", None),
bias=getattr(norm_module.norm, "bias", None),
scale0=scale0.contiguous(),
Expand All @@ -838,39 +864,45 @@ def _modulate(
index=index,
eps=norm_module.eps,
)
)
return x, residual_out, gate_result
return x, gate_result
else:
x, gate_result = fuse_layernorm_scale_shift_gate_select01_kernel(
x,
weight=getattr(norm_module.norm, "weight", None),
bias=getattr(norm_module.norm, "bias", None),
scale0=scale0.contiguous(),
shift0=shift0.contiguous(),
gate0=gate0.contiguous(),
scale1=scale1.contiguous(),
shift1=shift1.contiguous(),
gate1=gate1.contiguous(),
index=index,
eps=norm_module.eps,
actual_batch = x.shape[0]
shift0, shift1 = (
shift[:actual_batch],
shift[actual_batch : 2 * actual_batch],
)
return x, gate_result
scale0, scale1 = (
scale[:actual_batch],
scale[actual_batch : 2 * actual_batch],
)
gate0, gate1 = (
gate[:actual_batch],
gate[actual_batch : 2 * actual_batch],
)
index = index.to(dtype=torch.bool).unsqueeze(-1)
shift_result = torch.where(
index, shift1.unsqueeze(1), shift0.unsqueeze(1)
)
scale_result = torch.where(
index, scale1.unsqueeze(1), scale0.unsqueeze(1)
)
gate_result = torch.where(index, gate1.unsqueeze(1), gate0.unsqueeze(1))
else:
shift_result = shift.unsqueeze(1)
scale_result = scale.unsqueeze(1)
gate_result = gate.unsqueeze(1)
if is_scale_residual:
modulated, residual_out = norm_module(
residual=residual_x,
x=x,
gate=gate_x,
shift=shift_result,
scale=scale_result,
)
return modulated, residual_out, gate_result
else:
modulated = norm_module(x=x, shift=shift_result, scale=scale_result)
return modulated, gate_result
if is_scale_residual:
modulated, residual_out = norm_module(
residual=residual_x,
x=x,
gate=gate_x,
shift=shift_result,
scale=scale_result,
)
return modulated, residual_out, gate_result
else:
modulated = norm_module(x=x, shift=shift_result, scale=scale_result)
return modulated, gate_result

def forward(
self,
Expand Down Expand Up @@ -1104,8 +1136,8 @@ def build_modulate_index(self, img_shapes: tuple[int, int, int], device):
first_size = sample[0][0] * sample[0][1] * sample[0][2]
total_size = sum(s[0] * s[1] * s[2] for s in sample)
if sp_world_size > 1:
first_local_size = _local_seq_len(first_size)
tail_local_size = _local_seq_len(total_size - first_size)
first_local_size = _local_seq_len(first_size, sp_world_size)
tail_local_size = _local_seq_len(total_size - first_size, sp_world_size)
idx = torch.cat(
[
torch.zeros(first_local_size, device=device, dtype=torch.int),
Expand Down
Loading