Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ def postprocess_vae_encode(self, image_latents, vae):
def preprocess_decoding(self, latents, server_args=None, vae=None):
    # Identity hook: base implementation applies no transformation before VAE
    # decoding; subclasses may override to reshape/unscale latents.
    # `server_args` and `vae` are accepted for interface compatibility and unused here.
    return latents

def gather_latents_for_sp(self, latents):
def gather_latents_for_sp(self, latents, batch=None):
    """All-gather sequence-parallel video latents.

    Video latents have layout [B, C, T_local, H, W]; the SP shard lives on
    the time axis, so the gather happens on dim=2. `batch` is accepted for
    interface compatibility with subclasses and is unused here.
    """
    gathered = sequence_model_parallel_all_gather(latents, dim=2)
    return gathered
Expand Down Expand Up @@ -808,7 +808,7 @@ def shard_latents_for_sp(self, batch, latents):
sharded_tensor = sharded_tensor[:, rank_in_sp_group, :, :]
return sharded_tensor, True

def gather_latents_for_sp(self, latents):
def gather_latents_for_sp(self, latents, batch=None):
    """All-gather sequence-parallel image latents.

    Image latents have layout [B, S_local, D]; the SP shard lives on the
    sequence axis, so the gather happens on dim=1. `batch` is accepted for
    interface compatibility with subclasses and is unused here.
    """
    gathered = sequence_model_parallel_all_gather(latents, dim=1)
    return gathered
Expand Down Expand Up @@ -862,11 +862,11 @@ def shard_latents_for_sp(self, batch, latents):
sharded = latents[:, :, h0:h1, :].contiguous()
return sharded, True

def gather_latents_for_sp(self, latents):
def gather_latents_for_sp(self, latents, batch=None):
    """Gather SP-sharded latents, handling the 4-D spatially sharded case.

    4-D latents were split on dim=2 (H') by shard_latents_for_sp, so they
    are gathered on the same axis; any other rank is delegated to the
    parent implementation. A no-op when SP is disabled.
    """
    world_size = get_sp_world_size()
    if world_size <= 1:
        return latents
    if latents.dim() == 4:
        # Mirror the dim=2 (H') split performed in shard_latents_for_sp.
        return sequence_model_parallel_all_gather(latents, dim=2)
    return super().gather_latents_for_sp(latents, batch=batch)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -350,13 +350,13 @@ def shard_latents_for_sp(self, batch, latents):

return latents, True

def gather_latents_for_sp(self, latents):
def gather_latents_for_sp(self, latents, batch=None):
    """Undo SP sharding after denoising.

    Packed token latents of shape [B, S_local, D] are all-gathered along
    the sequence axis (dim=1); any other input is delegated to the parent
    class. A no-op when SP is disabled.
    """
    if get_sp_world_size() > 1:
        is_packed_tokens = isinstance(latents, torch.Tensor) and latents.ndim == 3
        if is_packed_tokens:
            return sequence_model_parallel_all_gather(
                latents.contiguous(), dim=1
            )
        return super().gather_latents_for_sp(latents, batch=batch)
    return latents

def maybe_pack_audio_latents(self, latents, batch_size, batch):
# If already packed (3D shape [B, T, C*F]), skip packing
Expand Down
178 changes: 115 additions & 63 deletions python/sglang/multimodal_gen/configs/pipeline_configs/zimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Callable

import torch
import torch.distributed as dist

from sglang.multimodal_gen.configs.models import DiTConfig, EncoderConfig, VAEConfig
from sglang.multimodal_gen.configs.models.dits.zimage import ZImageDitConfig
Expand All @@ -14,10 +15,8 @@
ImagePipelineConfig,
ModelTaskType,
)
from sglang.multimodal_gen.runtime.distributed.communication_op import (
sequence_model_parallel_all_gather,
)
from sglang.multimodal_gen.runtime.distributed.parallel_state import (
get_sp_group,
get_sp_parallel_rank,
get_sp_world_size,
)
Expand Down Expand Up @@ -86,8 +85,13 @@ def _ceil_to_multiple(x: int, m: int) -> int:
return x
return int(math.ceil(x / m) * m)

@staticmethod
def _split_evenly(total: int, parts: int) -> list[int]:
base, remainder = divmod(total, parts)
return [base + int(rank < remainder) for rank in range(parts)]

def _build_zimage_sp_plan(self, batch) -> dict:
"""Build a minimal SP plan on batch for zimage spatial sharding."""
"""Build an SP plan that preserves native spatial layout for Z-Image."""
sp_size = get_sp_world_size()
rank = get_sp_parallel_rank()

Expand All @@ -103,32 +107,62 @@ def _build_zimage_sp_plan(self, batch) -> dict:
batch.width // self.vae_config.arch_config.spatial_compression_ratio
)

# Rule: shard along the larger spatial dimension (W/H), implemented via optional H/W transpose.
# Choose the larger of H and W for sharding, so H_eff = max(H, W).
swap_hw = W > H
H_eff = W if swap_hw else H
W_eff = H if swap_hw else W
# ZImage patchifies [C, F, H, W] latents in native F/H/W order, so shard
# native H or W directly.
H_tok = H // self.PATCH_SIZE
W_tok = W // self.PATCH_SIZE

shard_options = []
for shard_axis, axis_tok, other_tok, tie_break in (
("h", H_tok, W_tok, 0),
("w", W_tok, H_tok, 1),
):
axis_sizes = self._split_evenly(axis_tok, sp_size)
local_seq_lens = [axis_size * other_tok for axis_size in axis_sizes]
img_seq_target = self._ceil_to_multiple(
max(local_seq_lens), self.SEQ_LEN_MULTIPLE
)
total_pad_tokens = img_seq_target * sp_size - (H_tok * W_tok)
shard_options.append(
(
total_pad_tokens,
-axis_tok,
tie_break,
shard_axis,
axis_sizes,
img_seq_target,
)
)

_, _, _, shard_axis, axis_sizes, img_seq_target = min(shard_options)
axis_start_tok = sum(axis_sizes[:rank])
axis_local_tok = axis_sizes[rank]

# ZImage uses PATCH_SIZE=2 for spatial patchify; shard in token space and convert back to latent rows.
H_tok = H_eff // self.PATCH_SIZE
W_tok = W_eff // self.PATCH_SIZE
H_tok_pad = self._ceil_to_multiple(H_tok, sp_size)
H_tok_local = H_tok_pad // sp_size
h0_tok = rank * H_tok_local
if shard_axis == "h":
h0_tok = axis_start_tok
w0_tok = 0
local_h_tok = axis_local_tok
local_w_tok = W_tok
else:
h0_tok = 0
w0_tok = axis_start_tok
local_h_tok = H_tok
local_w_tok = axis_local_tok

plan = {
"sp_size": sp_size,
"rank": rank,
"swap_hw": swap_hw,
"H": H,
"W": W,
"H_eff": H_eff,
"W_eff": W_eff,
"H_tok": H_tok,
"W_tok": W_tok,
"H_tok_pad": H_tok_pad,
"H_tok_local": H_tok_local,
"shard_axis": shard_axis,
"shard_sizes_tok": axis_sizes,
"h0_tok": h0_tok,
"w0_tok": w0_tok,
"local_h_tok": local_h_tok,
"local_w_tok": local_w_tok,
"img_seq_target": img_seq_target,
}
batch._zimage_sp_plan = plan
return plan
Expand All @@ -154,51 +188,55 @@ def shard_latents_for_sp(self, batch, latents):
return latents, False

plan = self._get_zimage_sp_plan(batch)
if plan["shard_axis"] == "h":
h0 = plan["h0_tok"] * self.PATCH_SIZE
h1 = (plan["h0_tok"] + plan["local_h_tok"]) * self.PATCH_SIZE
return latents[:, :, :, h0:h1, :].contiguous(), True

# Layout: [B, C, T, H, W]. Always shard on dim=3 by optionally swapping H/W.
if plan["swap_hw"]:
latents = latents.transpose(3, 4).contiguous()

# Pad on effective-H so that H_tok is divisible by sp.
H_eff = latents.size(3)
w0 = plan["w0_tok"] * self.PATCH_SIZE
w1 = (plan["w0_tok"] + plan["local_w_tok"]) * self.PATCH_SIZE
return latents[:, :, :, :, w0:w1].contiguous(), True

H_tok = H_eff // self.PATCH_SIZE
pad_tok = plan["H_tok_pad"] - H_tok
pad_lat = pad_tok * self.PATCH_SIZE
if pad_lat > 0:
pad = latents[:, :, :, -1:, :].repeat(1, 1, 1, pad_lat, 1)
latents = torch.cat([latents, pad], dim=3)
h0 = plan["h0_tok"] * self.PATCH_SIZE
h1 = (plan["h0_tok"] + plan["H_tok_local"]) * self.PATCH_SIZE
latents = latents[:, :, :, h0:h1, :]

batch._zimage_sp_swap_hw = plan["swap_hw"]
return latents, True

def gather_latents_for_sp(self, latents):
# Gather on effective-H dim=3 (matches shard_latents_for_sp); swap-back is handled in post_denoising_loop.
def gather_latents_for_sp(self, latents, batch):
# Gather native H/W shards by padding to a common collective shape, then crop.
latents = latents.contiguous()
if get_sp_world_size() <= 1 or latents.dim() != 5:
if get_sp_world_size() <= 1 or latents.dim() not in (4, 5, 6):
return latents
return sequence_model_parallel_all_gather(latents, dim=3)

assert batch is not None
plan = self._get_zimage_sp_plan(batch)
if latents.dim() == 4:
shard_dim = 2 if plan["shard_axis"] == "h" else 3
elif latents.dim() == 5:
shard_dim = 3 if plan["shard_axis"] == "h" else 4
else:
shard_dim = 4 if plan["shard_axis"] == "h" else 5
max_axis_tok = max(plan["shard_sizes_tok"])
max_axis_lat = max_axis_tok * self.PATCH_SIZE

pad_shape = list(latents.shape)
pad_shape[shard_dim] = max_axis_lat
padded = latents.new_zeros(pad_shape)
axis_len = latents.shape[shard_dim]
padded_slices = [slice(None)] * latents.dim()
padded_slices[shard_dim] = slice(axis_len)
padded[tuple(padded_slices)] = latents

gathered = [torch.empty_like(padded) for _ in range(plan["sp_size"])]
dist.all_gather(gathered, padded, group=get_sp_group().device_group)

pieces = []
for rank, tensor in enumerate(gathered):
axis_lat = plan["shard_sizes_tok"][rank] * self.PATCH_SIZE
gather_slices = [slice(None)] * latents.dim()
gather_slices[shard_dim] = slice(axis_lat)
pieces.append(tensor[tuple(gather_slices)])
return torch.cat(pieces, dim=shard_dim)

def gather_noise_pred_for_sp(self, batch, noise_pred):
# Z-Image shards 5D latents on the effective-H axis, but ComfyUI noise_pred is 4D [B, C, H_local, W].
noise_pred = self.gather_latents_for_sp(noise_pred)
if noise_pred.dim() == 4:
# reconstruct the full spatial tensor
noise_pred = sequence_model_parallel_all_gather(
noise_pred.contiguous(), dim=2
)
# restore the original H/W orientation
if getattr(batch, "_zimage_sp_swap_hw", False):
noise_pred = noise_pred.transpose(2, 3).contiguous()
return noise_pred
return self.gather_latents_for_sp(noise_pred, batch=batch)

def post_denoising_loop(self, latents, batch):
# Restore swapped H/W and crop padded spatial dims before final reshape.
if latents.dim() == 5 and getattr(batch, "_zimage_sp_swap_hw", False):
latents = latents.transpose(3, 4).contiguous()
raw_latent_shape = getattr(batch, "raw_latent_shape", None)
if raw_latent_shape is not None and latents.dim() == 5:
latents = latents[:, :, :, : raw_latent_shape[3], : raw_latent_shape[4]]
Expand Down Expand Up @@ -239,16 +277,20 @@ def create_coordinate_grid(size, start=None, device=None):
).flatten(0, 2)
cap_freqs_cis = rotary_emb(cap_pos_ids)

# image (local, effective H-shard), offset after the full caption.
# Build image positions for the local native shard.
F_tokens = 1
H_tokens_local = plan["H_tok_local"]
W_tokens = plan["W_tok"]
H_tokens_local = plan["local_h_tok"]
W_tokens_local = plan["local_w_tok"]
img_pos_ids = create_coordinate_grid(
size=(F_tokens, H_tokens_local, W_tokens),
start=(cap_ori_len + cap_padding_len + 1, plan["h0_tok"], 0),
size=(F_tokens, H_tokens_local, W_tokens_local),
start=(
cap_ori_len + cap_padding_len + 1,
plan["h0_tok"],
plan["w0_tok"],
),
device=device,
).flatten(0, 2)
img_pad_len = (-img_pos_ids.shape[0]) % self.SEQ_LEN_MULTIPLE
img_pad_len = plan["img_seq_target"] - img_pos_ids.shape[0]
if img_pad_len:
pad_ids = create_coordinate_grid(
size=(1, 1, 1), start=(0, 0, 0), device=device
Expand Down Expand Up @@ -308,6 +350,11 @@ def prepare_pos_cond_kwargs(self, batch, device, rotary_emb, dtype):
rotary_emb,
batch,
),
"image_seq_len_target": (
self._get_zimage_sp_plan(batch)["img_seq_target"]
if get_sp_world_size() > 1
else None
),
}

def prepare_neg_cond_kwargs(self, batch, device, rotary_emb, dtype):
Expand All @@ -320,4 +367,9 @@ def prepare_neg_cond_kwargs(self, batch, device, rotary_emb, dtype):
rotary_emb,
batch,
),
"image_seq_len_target": (
self._get_zimage_sp_plan(batch)["img_seq_target"]
if get_sp_world_size() > 1
else None
),
}
23 changes: 21 additions & 2 deletions python/sglang/multimodal_gen/runtime/models/dits/zimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,12 +718,19 @@ def create_coordinate_grid(size, start=None, device=None):
grids = torch.meshgrid(axes, indexing="ij")
return torch.stack(grids, dim=-1)

@staticmethod
def _ceil_to_multiple(value: int, multiple: int) -> int:
if multiple <= 0:
return value
return int(math.ceil(value / multiple) * multiple)

def patchify_and_embed(
self,
all_image: List[torch.Tensor],
all_cap_feats: List[torch.Tensor],
patch_size: int,
f_patch_size: int,
image_seq_len_target: int | None = None,
):
assert len(all_image) == len(all_cap_feats) == 1

Expand Down Expand Up @@ -761,7 +768,12 @@ def patchify_and_embed(
F_tokens * H_tokens * W_tokens, pF * pH * pW * C
)
image_ori_len = image.size(0)
image_padding_len = (-image_ori_len) % SEQ_MULTI_OF
min_image_seq_len = self._ceil_to_multiple(image_ori_len, SEQ_MULTI_OF)
if image_seq_len_target is None:
image_seq_len_target = min_image_seq_len
else:
image_seq_len_target = max(min_image_seq_len, image_seq_len_target)
image_padding_len = image_seq_len_target - image_ori_len

# padded feature
image_padded_feat = torch.cat(
Expand All @@ -788,6 +800,7 @@ def forward(
patch_size=2,
f_patch_size=1,
freqs_cis=None,
image_seq_len_target: int | None = None,
**kwargs,
):
assert patch_size in self.all_patch_size
Expand All @@ -807,7 +820,13 @@ def forward(
x_size,
x_valid_lens,
cap_valid_lens,
) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size)
) = self.patchify_and_embed(
x,
cap_feats,
patch_size,
f_patch_size,
image_seq_len_target=image_seq_len_target,
)

x = torch.cat(x, dim=0)
x, _ = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x)
Expand Down
Loading
Loading