Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ class QwenImageVAEConfig(VAEConfig):
use_temporal_tiling: bool = False
use_parallel_tiling: bool = False

use_parallel_decode: bool = False

def get_vae_scale_factor(self):
return 2 ** len(self.arch_config.temperal_downsample)

Expand Down
137 changes: 84 additions & 53 deletions python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,10 @@
apply_flashinfer_rope_qk_inplace,
)
from sglang.multimodal_gen.runtime.models.dits.base import CachableDiT
from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum
from sglang.multimodal_gen.runtime.platforms import (
AttentionBackendEnum,
current_platform,
)
from sglang.multimodal_gen.runtime.utils.layerwise_offload import OffloadableDiTMixin
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger

Expand Down Expand Up @@ -826,30 +829,52 @@ def _modulate(

shift, scale, gate = mod_params.chunk(3, dim=-1)
if index is not None:
actual_batch = x.shape[0]
shift0, shift1 = (
shift[:actual_batch],
shift[actual_batch : 2 * actual_batch],
)
scale0, scale1 = (
scale[:actual_batch],
scale[actual_batch : 2 * actual_batch],
)
gate0, gate1 = gate[:actual_batch], gate[actual_batch : 2 * actual_batch]
if not x.is_contiguous():
x = x.contiguous()
if not index.is_contiguous():
index = index.contiguous()
if is_scale_residual:
if not residual_x.is_contiguous():
residual_x = residual_x.contiguous()
if not gate_x.is_contiguous():
gate_x = gate_x.contiguous()
x, residual_out, gate_result = (
fuse_residual_layernorm_scale_shift_gate_select01_kernel(
# ROCm currently fails to compile the select01 Triton kernel, so
# keep using the torch.where fallback there.
if x.is_cuda and not current_platform.is_hip():
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if x.is_cuda and not current_platform.is_hip(): this pattern is appearing repeatedly recently, please move the implementation to platform.py, including all the changes recently

Copy link
Copy Markdown
Contributor Author

@gxxx-hum gxxx-hum Mar 31, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I’ll move it under runtime/platforms and clean up the recent similar cases there as well.

actual_batch = x.shape[0]
shift0, shift1 = (
shift[:actual_batch],
shift[actual_batch : 2 * actual_batch],
)
scale0, scale1 = (
scale[:actual_batch],
scale[actual_batch : 2 * actual_batch],
)
gate0, gate1 = (
gate[:actual_batch],
gate[actual_batch : 2 * actual_batch],
)
if not x.is_contiguous():
x = x.contiguous()
if not index.is_contiguous():
index = index.contiguous()
if is_scale_residual:
if not residual_x.is_contiguous():
residual_x = residual_x.contiguous()
if not gate_x.is_contiguous():
gate_x = gate_x.contiguous()
x, residual_out, gate_result = (
fuse_residual_layernorm_scale_shift_gate_select01_kernel(
x,
residual=residual_x,
residual_gate=gate_x,
weight=getattr(norm_module.norm, "weight", None),
bias=getattr(norm_module.norm, "bias", None),
scale0=scale0.contiguous(),
shift0=shift0.contiguous(),
gate0=gate0.contiguous(),
scale1=scale1.contiguous(),
shift1=shift1.contiguous(),
gate1=gate1.contiguous(),
index=index,
eps=norm_module.eps,
)
)
return x, residual_out, gate_result
else:
x, gate_result = fuse_layernorm_scale_shift_gate_select01_kernel(
x,
residual=residual_x,
residual_gate=gate_x,
weight=getattr(norm_module.norm, "weight", None),
bias=getattr(norm_module.norm, "bias", None),
scale0=scale0.contiguous(),
Expand All @@ -861,39 +886,45 @@ def _modulate(
index=index,
eps=norm_module.eps,
)
)
return x, residual_out, gate_result
return x, gate_result
else:
x, gate_result = fuse_layernorm_scale_shift_gate_select01_kernel(
x,
weight=getattr(norm_module.norm, "weight", None),
bias=getattr(norm_module.norm, "bias", None),
scale0=scale0.contiguous(),
shift0=shift0.contiguous(),
gate0=gate0.contiguous(),
scale1=scale1.contiguous(),
shift1=shift1.contiguous(),
gate1=gate1.contiguous(),
index=index,
eps=norm_module.eps,
actual_batch = x.shape[0]
shift0, shift1 = (
shift[:actual_batch],
shift[actual_batch : 2 * actual_batch],
)
return x, gate_result
scale0, scale1 = (
scale[:actual_batch],
scale[actual_batch : 2 * actual_batch],
)
gate0, gate1 = (
gate[:actual_batch],
gate[actual_batch : 2 * actual_batch],
)
index = index.to(dtype=torch.bool).unsqueeze(-1)
shift_result = torch.where(
index, shift1.unsqueeze(1), shift0.unsqueeze(1)
)
scale_result = torch.where(
index, scale1.unsqueeze(1), scale0.unsqueeze(1)
)
gate_result = torch.where(index, gate1.unsqueeze(1), gate0.unsqueeze(1))
else:
shift_result = shift.unsqueeze(1)
scale_result = scale.unsqueeze(1)
gate_result = gate.unsqueeze(1)
if is_scale_residual:
modulated, residual_out = norm_module(
residual=residual_x,
x=x,
gate=gate_x,
shift=shift_result,
scale=scale_result,
)
return modulated, residual_out, gate_result
else:
modulated = norm_module(x=x, shift=shift_result, scale=scale_result)
return modulated, gate_result
if is_scale_residual:
modulated, residual_out = norm_module(
residual=residual_x,
x=x,
gate=gate_x,
shift=shift_result,
scale=scale_result,
)
return modulated, residual_out, gate_result
else:
modulated = norm_module(x=x, shift=shift_result, scale=scale_result)
return modulated, gate_result

def forward(
self,
Expand Down Expand Up @@ -1127,8 +1158,8 @@ def build_modulate_index(self, img_shapes: tuple[int, int, int], device):
first_size = sample[0][0] * sample[0][1] * sample[0][2]
total_size = sum(s[0] * s[1] * s[2] for s in sample)
if sp_world_size > 1:
first_local_size = _local_seq_len(first_size)
tail_local_size = _local_seq_len(total_size - first_size)
first_local_size = _local_seq_len(first_size, sp_world_size)
tail_local_size = _local_seq_len(total_size - first_size, sp_world_size)
idx = torch.cat(
[
torch.zeros(first_local_size, device=device, dtype=torch.int),
Expand Down
Loading
Loading