Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 111 additions & 2 deletions python/sglang/multimodal_gen/configs/pipeline_configs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import torch
from diffusers.image_processor import VaeImageProcessor
from einops import rearrange

from sglang.multimodal_gen.configs.models import (
DiTConfig,
Expand All @@ -18,6 +19,11 @@
)
from sglang.multimodal_gen.configs.models.encoders import BaseEncoderOutput
from sglang.multimodal_gen.configs.utils import update_config_from_args
from sglang.multimodal_gen.runtime.distributed import (
get_sp_parallel_rank,
get_sp_world_size,
sequence_model_parallel_all_gather,
)
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
from sglang.multimodal_gen.utils import (
FlexibleArgumentParser,
Expand Down Expand Up @@ -59,10 +65,45 @@ def postprocess_text(output: BaseEncoderOutput, _text_inputs) -> torch.tensor:
raise NotImplementedError


def shard_rotary_emb_for_sp(emb):
    """Return this rank's sequence-parallel shard of rotary embeddings [S, D].

    When S is not divisible by the SP world size, the table is first padded
    by repeating its final row. If the distributed runtime is unavailable or
    the SP world size is 1, the input is returned unchanged.
    """
    # Best-effort lookup of the SP degree; fall back to 1 (no sharding) when
    # the distributed state is not importable/initialized.
    try:
        from sglang.multimodal_gen.runtime.distributed.parallel_state import (
            get_sp_parallel_rank,
            get_sp_world_size,
        )

        sp_world_size = get_sp_world_size()
    except Exception:
        sp_world_size = 1

    total_len = emb.shape[0]
    remainder = total_len % sp_world_size
    if remainder:
        # Pad by repeating the last row so S divides evenly across ranks.
        extra = sp_world_size - remainder
        emb = torch.cat([emb, emb[-1:].repeat(extra, 1)], dim=0)

    if sp_world_size <= 1:
        return emb

    try:
        rank = get_sp_parallel_rank()
    except Exception:
        rank = 0
    shard_len = emb.shape[0] // sp_world_size
    offset = rank * shard_len
    return emb[offset : offset + shard_len]


# config for a single pipeline
@dataclass
class PipelineConfig:
"""Base configuration for all pipeline architectures."""
"""The base configuration class for a generation pipeline."""

task_type: ModelTaskType

Expand Down Expand Up @@ -163,9 +204,28 @@ def prepare_latent_shape(self, batch, batch_size, num_frames):
return shape

# called after latents are prepared
def pack_latents(self, latents, batch_size, batch):
    def maybe_pack_latents(self, latents, batch_size, batch):
        """Hook for packing latents after preparation; the base implementation is a no-op passthrough."""
        return latents

def gather_latents_for_sp(self, latents):
# For video latents [B, C, T_local, H, W], gather along time dim=2
latents = sequence_model_parallel_all_gather(latents, dim=2)
return latents

def shard_latents_for_sp(self, batch, latents):
# general logic for video models
sp_world_size, rank_in_sp_group = get_sp_world_size(), get_sp_parallel_rank()
if latents.dim() != 5:
return latents, False
time_dim = latents.shape[2]
if time_dim > 0 and time_dim % sp_world_size == 0:
sharded_tensor = rearrange(
latents, "b c (n t) h w -> b c n t h w", n=sp_world_size
).contiguous()
sharded_tensor = sharded_tensor[:, :, rank_in_sp_group, :, :, :]
return sharded_tensor, True
return latents, False

    def get_pos_prompt_embeds(self, batch):
        """Return the positive (conditional) prompt embeddings stored on the batch."""
        return batch.prompt_embeds

Expand Down Expand Up @@ -459,6 +519,55 @@ def update_pipeline_config(self, source_pipeline_dict: dict[str, Any]) -> None:
self.__post_init__()


@dataclass
class ImagePipelineConfig(PipelineConfig):
    """Base config for image generation pipelines with token-like latents [B, S, D]."""

    def shard_latents_for_sp(self, batch, latents):
        """Shard token latents [B, S, D] along the sequence dim for SP.

        If S is not a multiple of the SP degree, the sequence is zero-padded
        first and the pad length is accumulated on ``batch.sp_seq_pad`` so it
        can be stripped after the post-denoise gather.

        Returns:
            (local_shard, True) — this rank's [B, S/sp, D] slice.
        """
        sp_world_size = get_sp_world_size()
        rank_in_sp_group = get_sp_parallel_rank()
        seq_len = latents.shape[1]

        remainder = seq_len % sp_world_size
        if remainder != 0:
            pad_len = sp_world_size - remainder
            padding = torch.zeros(
                (latents.shape[0], pad_len, latents.shape[2]),
                dtype=latents.dtype,
                device=latents.device,
            )
            latents = torch.cat([latents, padding], dim=1)
            # Accumulate (not overwrite) so repeated sharding stays accounted for.
            batch.sp_seq_pad = int(getattr(batch, "sp_seq_pad", 0)) + pad_len

        grouped = rearrange(
            latents, "b (n s) d -> b n s d", n=sp_world_size
        ).contiguous()
        return grouped[:, rank_in_sp_group, :, :], True

    def gather_latents_for_sp(self, latents):
        """All-gather image latents [B, S_local, D] back to the full sequence (dim=1)."""
        return sequence_model_parallel_all_gather(latents, dim=1)

    def _unpad_and_unpack_latents(self, latents, batch):
        """Drop SP padding tokens and unpack [B, S, D] latents into spatial form.

        Returns:
            (latents, batch_size, channels, height, width) where ``latents``
            is permuted to [B, C//4, H//2, 2, W//2, 2], ready for the caller
            to reshape into [B, C//4, H, W].
        """
        vae_scale_factor = self.vae_config.arch_config.vae_scale_factor
        channels = self.dit_config.arch_config.in_channels
        batch_size = latents.shape[0]

        # Latent H/W rounded to even values (2x2 packing on top of VAE scaling).
        height = 2 * (int(batch.height) // (vae_scale_factor * 2))
        width = 2 * (int(batch.width) // (vae_scale_factor * 2))

        # Tokens beyond the expected count are SP padding — drop them first.
        expected_tokens = (height // 2) * (width // 2)
        if latents.shape[1] > expected_tokens:
            latents = latents[:, :expected_tokens, :]

        unpacked = latents.view(
            batch_size, height // 2, width // 2, channels // 4, 2, 2
        )
        unpacked = unpacked.permute(0, 3, 1, 4, 2, 5)
        return unpacked, batch_size, channels, height, width


@dataclass
class SlidingTileAttnConfig(PipelineConfig):
"""Configuration for sliding tile attention."""
Expand Down
50 changes: 25 additions & 25 deletions python/sglang/multimodal_gen/configs/pipeline_configs/flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,23 +14,26 @@
)
from sglang.multimodal_gen.configs.models.vaes.flux import FluxVAEConfig
from sglang.multimodal_gen.configs.pipeline_configs.base import (
ImagePipelineConfig,
ModelTaskType,
PipelineConfig,
preprocess_text,
shard_rotary_emb_for_sp,
)
from sglang.multimodal_gen.configs.pipeline_configs.hunyuan import (
clip_postprocess_text,
clip_preprocess_text,
)
from sglang.multimodal_gen.configs.pipeline_configs.qwen_image import _pack_latents


def t5_postprocess_text(outputs: BaseEncoderOutput, _text_inputs) -> torch.Tensor:
    """Return the T5 encoder's last hidden state as prompt embeddings (the raw text inputs are unused)."""
    return outputs.last_hidden_state


@dataclass
class FluxPipelineConfig(PipelineConfig):
# FIXME: duplicate with SamplingParams.guidance_scale?
class FluxPipelineConfig(ImagePipelineConfig):
"""Configuration for the FLUX pipeline."""

embedded_cfg_scale: float = 3.5

task_type: ModelTaskType = ModelTaskType.T2I
Expand Down Expand Up @@ -82,21 +85,14 @@ def prepare_latent_shape(self, batch, batch_size, num_frames):
shape = (batch_size, num_channels_latents, height, width)
return shape

def pack_latents(self, latents, batch_size, batch):
def maybe_pack_latents(self, latents, batch_size, batch):
height = 2 * (
batch.height // (self.vae_config.arch_config.vae_scale_factor * 2)
)
width = 2 * (batch.width // (self.vae_config.arch_config.vae_scale_factor * 2))
num_channels_latents = self.dit_config.arch_config.in_channels // 4
# pack latents
latents = latents.view(
batch_size, num_channels_latents, height // 2, 2, width // 2, 2
)
latents = latents.permute(0, 2, 4, 1, 3, 5)
latents = latents.reshape(
batch_size, (height // 2) * (width // 2), num_channels_latents * 4
)
return latents
return _pack_latents(latents, batch_size, num_channels_latents, height, width)

def get_pos_prompt_embeds(self, batch):
return batch.prompt_embeds[1]
Expand Down Expand Up @@ -133,23 +129,27 @@ def get_freqs_cis(self, prompt_embeds, width, height, device, rotary_emb):
original_width=width,
device=device,
)
ids = torch.cat([txt_ids, img_ids], dim=0).to(device=device)

# NOTE(mick): prepare it here, to avoid unnecessary computations
freqs_cis = rotary_emb.forward(ids)
return freqs_cis
img_cos, img_sin = rotary_emb.forward(img_ids)
img_cos = shard_rotary_emb_for_sp(img_cos)
img_sin = shard_rotary_emb_for_sp(img_sin)

txt_cos, txt_sin = rotary_emb.forward(txt_ids)

cos = torch.cat([txt_cos, img_cos], dim=0).to(device=device)
sin = torch.cat([txt_sin, img_sin], dim=0).to(device=device)
return cos, sin

def post_denoising_loop(self, latents, batch):
# unpack latents for flux
# VAE applies 8x compression on images but we must also account for packing which requires
# latent height and width to be divisible by 2.
batch_size = latents.shape[0]
channels = latents.shape[-1]
vae_scale_factor = self.vae_config.arch_config.vae_scale_factor
height = 2 * (int(batch.height) // (vae_scale_factor * 2))
width = 2 * (int(batch.width) // (vae_scale_factor * 2))

latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5)
(
latents,
batch_size,
channels,
height,
width,
) = self._unpad_and_unpack_latents(latents, batch)
latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
return latents

Expand Down
Loading
Loading