diff --git a/python/sglang/multimodal_gen/configs/pipeline_configs/base.py b/python/sglang/multimodal_gen/configs/pipeline_configs/base.py index d3c86f557045..55b3db3bc25d 100644 --- a/python/sglang/multimodal_gen/configs/pipeline_configs/base.py +++ b/python/sglang/multimodal_gen/configs/pipeline_configs/base.py @@ -9,6 +9,7 @@ import torch from diffusers.image_processor import VaeImageProcessor +from einops import rearrange from sglang.multimodal_gen.configs.models import ( DiTConfig, @@ -18,6 +19,11 @@ ) from sglang.multimodal_gen.configs.models.encoders import BaseEncoderOutput from sglang.multimodal_gen.configs.utils import update_config_from_args +from sglang.multimodal_gen.runtime.distributed import ( + get_sp_parallel_rank, + get_sp_world_size, + sequence_model_parallel_all_gather, +) from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger from sglang.multimodal_gen.utils import ( FlexibleArgumentParser, @@ -59,10 +65,45 @@ def postprocess_text(output: BaseEncoderOutput, _text_inputs) -> torch.tensor: raise NotImplementedError +def shard_rotary_emb_for_sp(emb): + """ + Shard rotary embeddings [S, D] along sequence for SP. + If S is not divisible by SP degree, pad by repeating the last row. + """ + # Sequence Parallelism: slice image RoPE to local shard if enabled + try: + from sglang.multimodal_gen.runtime.distributed.parallel_state import ( + get_sp_parallel_rank, + get_sp_world_size, + ) + + sp_world_size = get_sp_world_size() + except Exception: + sp_world_size = 1 + seq_len = emb.shape[0] + if seq_len % sp_world_size != 0: + pad_len = sp_world_size - (seq_len % sp_world_size) + pad = emb[-1:].repeat(pad_len, 1) + emb = torch.cat([emb, pad], dim=0) + if sp_world_size > 1: + try: + rank = get_sp_parallel_rank() + except Exception: + rank = 0 + seq_len = emb.shape[0] + local_len = seq_len // sp_world_size + start = rank * local_len + end = start + local_len + emb = emb[start:end] + return emb + else: + return emb + + # config for a single pipeline @dataclass class PipelineConfig: - """Base configuration for all pipeline architectures.""" + """The base configuration class for a generation pipeline.""" task_type: ModelTaskType @@ -163,9 +204,28 @@ def prepare_latent_shape(self, batch, batch_size, num_frames): return shape # called after latents are prepared - def pack_latents(self, latents, batch_size, batch): + def maybe_pack_latents(self, latents, batch_size, batch): return latents + def gather_latents_for_sp(self, latents): + # For video latents [B, C, T_local, H, W], gather along time dim=2 + latents = sequence_model_parallel_all_gather(latents, dim=2) + return latents + + def shard_latents_for_sp(self, batch, latents): + # general logic for video models + sp_world_size, rank_in_sp_group = get_sp_world_size(), get_sp_parallel_rank() + if latents.dim() != 5: + return latents, False + time_dim = latents.shape[2] + if time_dim > 0 and time_dim % sp_world_size == 0: + sharded_tensor = rearrange( + latents, "b c (n t) h w -> b c n t h w", n=sp_world_size + ).contiguous() + sharded_tensor = sharded_tensor[:, :, rank_in_sp_group, :, :, :] + return sharded_tensor, True + return latents, False + def get_pos_prompt_embeds(self, batch): return batch.prompt_embeds @@ -459,6 +519,55 @@ def update_pipeline_config(self, source_pipeline_dict: dict[str, Any]) -> None: self.__post_init__() +@dataclass +class ImagePipelineConfig(PipelineConfig): + """Base config for image generation pipelines with token-like latents [B, S, D].""" + + def shard_latents_for_sp(self, batch, latents): + sp_world_size, rank_in_sp_group = get_sp_world_size(), get_sp_parallel_rank() + seq_len = latents.shape[1] + + # Pad to next multiple of SP degree if needed + if seq_len % sp_world_size != 0: + pad_len = sp_world_size - (seq_len % sp_world_size) + pad = torch.zeros( + (latents.shape[0], pad_len, latents.shape[2]), + dtype=latents.dtype, + device=latents.device, + ) + latents = torch.cat([latents, pad], dim=1) + # Record padding length for later unpad + batch.sp_seq_pad = int(getattr(batch, "sp_seq_pad", 0)) + pad_len + + sharded_tensor = rearrange( + latents, "b (n s) d -> b n s d", n=sp_world_size + ).contiguous() + sharded_tensor = sharded_tensor[:, rank_in_sp_group, :, :] + return sharded_tensor, True + + def gather_latents_for_sp(self, latents): + # For image latents [B, S_local, D], gather along sequence dim=1 + latents = sequence_model_parallel_all_gather(latents, dim=1) + return latents + + def _unpad_and_unpack_latents(self, latents, batch): + vae_scale_factor = self.vae_config.arch_config.vae_scale_factor + channels = self.dit_config.arch_config.in_channels + batch_size = latents.shape[0] + + height = 2 * (int(batch.height) // (vae_scale_factor * 2)) + width = 2 * (int(batch.width) // (vae_scale_factor * 2)) + + # If SP padding was applied, remove extra tokens before reshaping + target_tokens = (height // 2) * (width // 2) + if latents.shape[1] > target_tokens: + latents = latents[:, :target_tokens, :] + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + return latents, batch_size, channels, height, width + + @dataclass class SlidingTileAttnConfig(PipelineConfig): """Configuration for sliding tile attention.""" diff --git a/python/sglang/multimodal_gen/configs/pipeline_configs/flux.py b/python/sglang/multimodal_gen/configs/pipeline_configs/flux.py index 9c377689bfd3..60d194d9bdab 100644 --- a/python/sglang/multimodal_gen/configs/pipeline_configs/flux.py +++ b/python/sglang/multimodal_gen/configs/pipeline_configs/flux.py @@ -14,14 +14,16 @@ ) from sglang.multimodal_gen.configs.models.vaes.flux import FluxVAEConfig from sglang.multimodal_gen.configs.pipeline_configs.base import ( + ImagePipelineConfig, ModelTaskType, - PipelineConfig, preprocess_text, + shard_rotary_emb_for_sp, ) from sglang.multimodal_gen.configs.pipeline_configs.hunyuan import ( clip_postprocess_text, clip_preprocess_text, ) +from sglang.multimodal_gen.configs.pipeline_configs.qwen_image import _pack_latents def t5_postprocess_text(outputs: BaseEncoderOutput, _text_inputs) -> torch.Tensor: @@ -29,8 +31,9 @@ def t5_postprocess_text(outputs: BaseEncoderOutput, _text_inputs) -> torch.Tenso @dataclass -class FluxPipelineConfig(PipelineConfig): - # FIXME: duplicate with SamplingParams.guidance_scale? +class FluxPipelineConfig(ImagePipelineConfig): + """Configuration for the FLUX pipeline.""" + embedded_cfg_scale: float = 3.5 task_type: ModelTaskType = ModelTaskType.T2I @@ -82,21 +85,14 @@ def prepare_latent_shape(self, batch, batch_size, num_frames): shape = (batch_size, num_channels_latents, height, width) return shape - def pack_latents(self, latents, batch_size, batch): + def maybe_pack_latents(self, latents, batch_size, batch): height = 2 * ( batch.height // (self.vae_config.arch_config.vae_scale_factor * 2) ) width = 2 * (batch.width // (self.vae_config.arch_config.vae_scale_factor * 2)) num_channels_latents = self.dit_config.arch_config.in_channels // 4 # pack latents - latents = latents.view( - batch_size, num_channels_latents, height // 2, 2, width // 2, 2 - ) - latents = latents.permute(0, 2, 4, 1, 3, 5) - latents = latents.reshape( - batch_size, (height // 2) * (width // 2), num_channels_latents * 4 - ) - return latents + return _pack_latents(latents, batch_size, num_channels_latents, height, width) def get_pos_prompt_embeds(self, batch): return batch.prompt_embeds[1] @@ -133,23 +129,27 @@ def get_freqs_cis(self, prompt_embeds, width, height, device, rotary_emb): original_width=width, device=device, ) - ids = torch.cat([txt_ids, img_ids], dim=0).to(device=device) + # NOTE(mick): prepare it here, to avoid unnecessary computations - freqs_cis = rotary_emb.forward(ids) - return freqs_cis + img_cos, img_sin = rotary_emb.forward(img_ids) + img_cos = shard_rotary_emb_for_sp(img_cos) + img_sin = shard_rotary_emb_for_sp(img_sin) + + txt_cos, txt_sin = rotary_emb.forward(txt_ids) + + cos = torch.cat([txt_cos, img_cos], dim=0).to(device=device) + sin = torch.cat([txt_sin, img_sin], dim=0).to(device=device) + return cos, sin def post_denoising_loop(self, latents, batch): # unpack latents for flux - # VAE applies 8x compression on images but we must also account for packing which requires - # latent height and width to be divisible by 2. - batch_size = latents.shape[0] - channels = latents.shape[-1] - vae_scale_factor = self.vae_config.arch_config.vae_scale_factor - height = 2 * (int(batch.height) // (vae_scale_factor * 2)) - width = 2 * (int(batch.width) // (vae_scale_factor * 2)) - - latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) - latents = latents.permute(0, 3, 1, 4, 2, 5) + ( + latents, + batch_size, + channels, + height, + width, + ) = self._unpad_and_unpack_latents(latents, batch) latents = latents.reshape(batch_size, channels // (2 * 2), height, width) return latents diff --git a/python/sglang/multimodal_gen/configs/pipeline_configs/qwen_image.py b/python/sglang/multimodal_gen/configs/pipeline_configs/qwen_image.py index 10bd4e610b9f..d89bb7397066 100644 --- a/python/sglang/multimodal_gen/configs/pipeline_configs/qwen_image.py +++ b/python/sglang/multimodal_gen/configs/pipeline_configs/qwen_image.py @@ -10,8 +10,9 @@ from sglang.multimodal_gen.configs.models.encoders.qwen_image import Qwen2_5VLConfig from sglang.multimodal_gen.configs.models.vaes.qwenimage import QwenImageVAEConfig from sglang.multimodal_gen.configs.pipeline_configs.base import ( + ImagePipelineConfig, ModelTaskType, - PipelineConfig, + shard_rotary_emb_for_sp, ) from sglang.multimodal_gen.utils import calculate_dimensions @@ -64,9 +65,10 @@ def _pack_latents(latents, batch_size, num_channels_latents, height, width): @dataclass -class QwenImagePipelineConfig(PipelineConfig): - should_use_guidance: bool = False +class QwenImagePipelineConfig(ImagePipelineConfig): + """Configuration for the QwenImage pipeline.""" + should_use_guidance: bool = False task_type: ModelTaskType = ModelTaskType.T2I vae_tiling: bool = False @@ -105,15 +107,14 @@ def get_vae_scale_factor(self): return self.vae_config.arch_config.vae_scale_factor def prepare_latent_shape(self, batch, batch_size, num_frames): - height = 2 * ( - batch.height // (self.vae_config.arch_config.vae_scale_factor * 2) - ) - width = 2 * (batch.width // (self.vae_config.arch_config.vae_scale_factor * 2)) + vae_scale_factor = self.vae_config.arch_config.vae_scale_factor + height = 2 * (batch.height // (vae_scale_factor * 2)) + width = 2 * (batch.width // (vae_scale_factor * 2)) num_channels_latents = self.dit_config.arch_config.in_channels // 4 - shape = (batch_size, num_channels_latents, height, width) + shape = (batch_size, 1, num_channels_latents, height, width) return shape - def pack_latents(self, latents, batch_size, batch): + def maybe_pack_latents(self, latents, batch_size, batch): height = 2 * ( batch.height // (self.vae_config.arch_config.vae_scale_factor * 2) ) @@ -124,6 +125,7 @@ def pack_latents(self, latents, batch_size, batch): @staticmethod def get_freqs_cis(img_shapes, txt_seq_lens, rotary_emb, device, dtype): + # img_shapes: for global entire image img_freqs, txt_freqs = rotary_emb(img_shapes, txt_seq_lens, device=device) img_cos, img_sin = ( @@ -134,139 +136,128 @@ def get_freqs_cis(img_shapes, txt_seq_lens, rotary_emb, device, dtype): txt_freqs.real.to(dtype=dtype), txt_freqs.imag.to(dtype=dtype), ) + return (img_cos, img_sin), (txt_cos, txt_sin) - def prepare_pos_cond_kwargs(self, batch, device, rotary_emb, dtype): - batch_size = batch.latents.shape[0] + def _prepare_cond_kwargs(self, batch, prompt_embeds, rotary_emb, device, dtype): + batch_size = prompt_embeds[0].shape[0] + height = batch.height + width = batch.width vae_scale_factor = self.vae_config.arch_config.vae_scale_factor img_shapes = [ [ ( 1, - batch.height // vae_scale_factor // 2, - batch.width // vae_scale_factor // 2, + height // vae_scale_factor // 2, + width // vae_scale_factor // 2, ) ] ] * batch_size - txt_seq_lens = [batch.prompt_embeds[0].shape[1]] - return { - "img_shapes": img_shapes, - "txt_seq_lens": txt_seq_lens, - "freqs_cis": QwenImagePipelineConfig.get_freqs_cis( - img_shapes, txt_seq_lens, rotary_emb, device, dtype - ), - } - - def prepare_neg_cond_kwargs(self, batch, device, rotary_emb, dtype): - batch_size = batch.latents.shape[0] - vae_scale_factor = self.vae_config.arch_config.vae_scale_factor + txt_seq_lens = [prompt_embeds[0].shape[1]] - img_shapes = [ - [ - ( - 1, - batch.height // vae_scale_factor // 2, - batch.width // vae_scale_factor // 2, - ) - ] - ] * batch_size + (img_cos, img_sin), (txt_cos, txt_sin) = self.get_freqs_cis( + img_shapes, txt_seq_lens, rotary_emb, device, dtype + ) - txt_seq_lens = [batch.negative_prompt_embeds[0].shape[1]] + img_cos = shard_rotary_emb_for_sp(img_cos) + img_sin = shard_rotary_emb_for_sp(img_sin) return { - "img_shapes": img_shapes, "txt_seq_lens": txt_seq_lens, - "freqs_cis": QwenImagePipelineConfig.get_freqs_cis( - img_shapes, txt_seq_lens, rotary_emb, device, dtype - ), + "freqs_cis": ((img_cos, img_sin), (txt_cos, txt_sin)), } - def post_denoising_loop(self, latents, batch): - # VAE applies 8x compression on images but we must also account for packing which requires - # latent height and width to be divisible by 2. - batch_size = latents.shape[0] - channels = latents.shape[-1] - vae_scale_factor = self.vae_config.arch_config.vae_scale_factor - height = 2 * (int(batch.height) // (vae_scale_factor * 2)) - width = 2 * (int(batch.width) // (vae_scale_factor * 2)) + def prepare_pos_cond_kwargs(self, batch, device, rotary_emb, dtype): + return self._prepare_cond_kwargs( + batch, batch.prompt_embeds, rotary_emb, device, dtype + ) + + def prepare_neg_cond_kwargs(self, batch, device, rotary_emb, dtype): + return self._prepare_cond_kwargs( + batch, batch.negative_prompt_embeds, rotary_emb, device, dtype + ) - latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) - latents = latents.permute(0, 3, 1, 4, 2, 5) + def post_denoising_loop(self, latents, batch): + # unpack latents for qwen-image + ( + latents, + batch_size, + channels, + height, + width, + ) = self._unpad_and_unpack_latents(latents, batch) latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width) return latents class QwenImageEditPipelineConfig(QwenImagePipelineConfig): + """Configuration for the QwenImageEdit pipeline.""" + task_type: ModelTaskType = ModelTaskType.I2I - def prepare_pos_cond_kwargs(self, batch, device, rotary_emb, dtype): - # TODO: lots of duplications here + def _prepare_edit_cond_kwargs( + self, batch, prompt_embeds, rotary_emb, device, dtype + ): batch_size = batch.latents.shape[0] + assert batch_size == 1 height = batch.height width = batch.width image = batch.pil_image image_size = image[0].size if isinstance(image, list) else image.size - calculated_width, calculated_height, _ = calculate_dimensions( + edit_width, edit_height, _ = calculate_dimensions( 1024 * 1024, image_size[0] / image_size[1] ) vae_scale_factor = self.get_vae_scale_factor() + img_shapes = [ [ - (1, height // vae_scale_factor // 2, width // vae_scale_factor // 2), ( 1, - calculated_height // vae_scale_factor // 2, - calculated_width // vae_scale_factor // 2, + height // vae_scale_factor // 2, + width // vae_scale_factor // 2, ), - ] - ] * batch_size - txt_seq_lens = [batch.prompt_embeds[0].shape[1]] - return { - "img_shapes": img_shapes, - "txt_seq_lens": txt_seq_lens, - "freqs_cis": QwenImagePipelineConfig.get_freqs_cis( - img_shapes, txt_seq_lens, rotary_emb, device, dtype - ), - } - - def prepare_neg_cond_kwargs(self, batch, device, rotary_emb, dtype): - batch_size = batch.latents.shape[0] - height = batch.height - width = batch.width - image = batch.pil_image - image_size = image[0].size if isinstance(image, list) else image.size - calculated_width, calculated_height, _ = calculate_dimensions( - 1024 * 1024, image_size[0] / image_size[1] - ) - vae_scale_factor = self.get_vae_scale_factor() - img_shapes = [ - [ - (1, height // vae_scale_factor // 2, width // vae_scale_factor // 2), ( 1, - calculated_height // vae_scale_factor // 2, - calculated_width // vae_scale_factor // 2, + edit_height // vae_scale_factor // 2, + edit_width // vae_scale_factor // 2, ), - ] + ], ] * batch_size + txt_seq_lens = [prompt_embeds[0].shape[1]] + (img_cos, img_sin), (txt_cos, txt_sin) = QwenImagePipelineConfig.get_freqs_cis( + img_shapes, txt_seq_lens, rotary_emb, device, dtype + ) + + # perform sp shard on noisy image tokens + noisy_img_seq_len = ( + 1 * (height // vae_scale_factor // 2) * (width // vae_scale_factor // 2) + ) + + noisy_img_cos = shard_rotary_emb_for_sp(img_cos[:noisy_img_seq_len, :]) + noisy_img_sin = shard_rotary_emb_for_sp(img_sin[:noisy_img_seq_len, :]) + + # concat back the img_cos for input image (since it is not sp-shared later) + img_cos = torch.cat([noisy_img_cos, img_cos[noisy_img_seq_len:, :]], dim=0).to( + device=device + ) + img_sin = torch.cat([noisy_img_sin, img_sin[noisy_img_seq_len:, :]], dim=0).to( + device=device + ) - txt_seq_lens = [batch.negative_prompt_embeds[0].shape[1]] return { - "img_shapes": img_shapes, "txt_seq_lens": txt_seq_lens, - "freqs_cis": QwenImagePipelineConfig.get_freqs_cis( - img_shapes, txt_seq_lens, rotary_emb, device, dtype - ), + "freqs_cis": ((img_cos, img_sin), (txt_cos, txt_sin)), } - def prepare_latent_shape(self, batch, batch_size, num_frames): - vae_scale_factor = self.vae_config.arch_config.vae_scale_factor - height = 2 * (batch.height // (vae_scale_factor * 2)) + def prepare_pos_cond_kwargs(self, batch, device, rotary_emb, dtype): + return self._prepare_edit_cond_kwargs( + batch, batch.prompt_embeds, rotary_emb, device, dtype + ) - width = 2 * (batch.width // (vae_scale_factor * 2)) - num_channels_latents = self.dit_config.arch_config.in_channels // 4 - shape = (batch_size, 1, num_channels_latents, height, width) - return shape + def prepare_neg_cond_kwargs(self, batch, device, rotary_emb, dtype): + return self._prepare_edit_cond_kwargs( + batch, batch.negative_prompt_embeds, rotary_emb, device, dtype + ) def preprocess_image(self, image, image_processor): image_size = image[0].size if isinstance(image, list) else image.size @@ -290,5 +281,6 @@ def adjust_size(self, width, height, image): return width, height def slice_noise_pred(self, noise, latents): + # remove noise over input image noise = noise[:, : latents.size(1)] return noise diff --git a/python/sglang/multimodal_gen/configs/sample/base.py b/python/sglang/multimodal_gen/configs/sample/base.py index 0445271c9320..18b4ea276aa3 100644 --- a/python/sglang/multimodal_gen/configs/sample/base.py +++ b/python/sglang/multimodal_gen/configs/sample/base.py @@ -507,6 +507,7 @@ def _merge_with_user_params(self, user_params): if user_params is None: return + # user is not allowed to modify any param defined in the SamplingParams subclass subclass_defined_fields = set(type(self).__annotations__.keys()) # Compare against current instance to avoid constructing a default instance diff --git a/python/sglang/multimodal_gen/runtime/entrypoints/diffusion_generator.py b/python/sglang/multimodal_gen/runtime/entrypoints/diffusion_generator.py index f8b7c28bd92e..945cbe81aa60 100644 --- a/python/sglang/multimodal_gen/runtime/entrypoints/diffusion_generator.py +++ b/python/sglang/multimodal_gen/runtime/entrypoints/diffusion_generator.py @@ -284,7 +284,7 @@ def generate( # TODO: send batch when supported for request_idx, req in enumerate(requests): logger.info( - "Processing prompt %d/%d: %s...", + "Processing prompt: %d/%d: %s", request_idx + 1, len(requests), req.prompt[:100], diff --git a/python/sglang/multimodal_gen/runtime/layers/attention/layer.py b/python/sglang/multimodal_gen/runtime/layers/attention/layer.py index b7faea7895a7..df4f377dfa56 100644 --- a/python/sglang/multimodal_gen/runtime/layers/attention/layer.py +++ b/python/sglang/multimodal_gen/runtime/layers/attention/layer.py @@ -170,7 +170,7 @@ def forward( replicated_k: torch.Tensor | None = None, replicated_v: torch.Tensor | None = None, gate_compress: torch.Tensor | None = None, - ) -> tuple[torch.Tensor, torch.Tensor | None]: + ) -> torch.Tensor: """Forward pass for distributed attention. Args: @@ -212,16 +212,14 @@ def forward( q, k, v, gate_compress=gate_compress, attn_metadata=ctx_attn_metadata ) # type: ignore[call-arg] - # Redistribute back if using sequence parallelism - replicated_output = None - # Apply backend-specific postprocess_output output = self.attn_impl.postprocess_output(output, ctx_attn_metadata) output = sequence_model_parallel_all_to_all_4D( output, scatter_dim=1, gather_dim=2 ) - return output, replicated_output + + return output class LocalAttention(nn.Module): @@ -309,7 +307,7 @@ def __init__( causal: bool = False, supported_attention_backends: set[AttentionBackendEnum] | None = None, prefix: str = "", - dropout_p: float = 0.0, + dropout_rate: float = 0.0, **extra_impl_args, ) -> None: super().__init__() @@ -341,7 +339,7 @@ def __init__( self.backend = backend_name_to_enum(attn_backend.get_name()) self.dtype = dtype self.causal = causal - self.dropout_p = dropout_p + self.dropout_p = dropout_rate def forward( self, @@ -351,7 +349,7 @@ def forward( replicated_q: torch.Tensor | None = None, replicated_k: torch.Tensor | None = None, replicated_v: torch.Tensor | None = None, - ) -> tuple[torch.Tensor, torch.Tensor | None]: + ) -> torch.Tensor: """ Forward pass for USPAttention. @@ -367,7 +365,7 @@ def forward( if get_sequence_parallel_world_size() == 1: # No sequence parallelism, just run local attention. out = self.attn_impl.forward(q, k, v, ctx_attn_metadata) - return out, None + return out # Ulysses-style All-to-All for sequence/head sharding if get_ulysses_parallel_world_size() > 1: @@ -395,4 +393,4 @@ def forward( # -> [B, S_local, H, D] out = _usp_output_all_to_all(out, head_dim=2) - return out, None + return out diff --git a/python/sglang/multimodal_gen/runtime/models/dits/flux.py b/python/sglang/multimodal_gen/runtime/models/dits/flux.py index 75fa2d3ebe07..ab31450512ac 100644 --- a/python/sglang/multimodal_gen/runtime/models/dits/flux.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/flux.py @@ -32,7 +32,7 @@ from torch.nn import LayerNorm as LayerNorm from sglang.multimodal_gen.configs.models.dits.flux import FluxConfig -from sglang.multimodal_gen.runtime.layers.attention import LocalAttention +from sglang.multimodal_gen.runtime.layers.attention import USPAttention # from sglang.multimodal_gen.runtime.layers.layernorm import LayerNorm as LayerNorm from sglang.multimodal_gen.runtime.layers.layernorm import RMSNorm @@ -149,17 +149,17 @@ def __init__( self.to_add_out = ReplicatedLinear(self.inner_dim, query_dim, bias=out_bias) # Scaled dot product attention - self.attn = LocalAttention( + self.attn = USPAttention( num_heads=num_heads, head_size=self.head_dim, dropout_rate=0, softmax_scale=None, causal=False, - supported_attention_backends=( + supported_attention_backends={ AttentionBackendEnum.FA, AttentionBackendEnum.TORCH_SDPA, AttentionBackendEnum.SAGE_ATTN, - ), + }, ) def forward( diff --git a/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py b/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py index c8d67ace4ba9..989d6d5286b1 100644 --- a/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py @@ -14,7 +14,7 @@ from diffusers.models.normalization import AdaLayerNormContinuous from sglang.multimodal_gen.configs.models.dits.qwenimage import QwenImageDitConfig -from sglang.multimodal_gen.runtime.layers.attention import LocalAttention +from sglang.multimodal_gen.runtime.layers.attention import USPAttention from sglang.multimodal_gen.runtime.layers.layernorm import LayerNorm, RMSNorm from sglang.multimodal_gen.runtime.layers.linear import ReplicatedLinear from sglang.multimodal_gen.runtime.layers.triton_ops import ( @@ -282,7 +282,7 @@ def __init__( self.norm_added_k = RMSNorm(head_dim, eps=eps) # Scaled dot product attention - self.attn = LocalAttention( + self.attn = USPAttention( num_heads=num_heads, head_size=self.head_dim, dropout_rate=0, @@ -301,7 +301,7 @@ def forward( image_rotary_emb: tuple[torch.Tensor, torch.Tensor], **cross_attention_kwargs, ): - seq_txt = encoder_hidden_states.shape[1] + seq_len_txt = encoder_hidden_states.shape[1] # Compute QKV for image stream (sample projections) img_query, _ = self.to_q(hidden_states) @@ -366,8 +366,8 @@ def forward( joint_hidden_states = joint_hidden_states.to(joint_query.dtype) # Split attention outputs back - txt_attn_output = joint_hidden_states[:, :seq_txt, :] # Text part - img_attn_output = joint_hidden_states[:, seq_txt:, :] # Image part + txt_attn_output = joint_hidden_states[:, :seq_len_txt, :] # Text part + img_attn_output = joint_hidden_states[:, seq_len_txt:, :] # Image part # Apply output projections img_attn_output, _ = self.to_out[0](img_attn_output) @@ -568,7 +568,6 @@ def forward( encoder_hidden_states: torch.Tensor = None, encoder_hidden_states_mask: torch.Tensor = None, timestep: torch.LongTensor = None, - img_shapes: Optional[List[Tuple[int, int, int]]] = None, txt_seq_lens: Optional[List[int]] = None, freqs_cis: tuple[torch.Tensor, torch.Tensor] = None, guidance: torch.Tensor = None, # TODO: this should probably be removed diff --git a/python/sglang/multimodal_gen/runtime/models/dits/stepvideo.py b/python/sglang/multimodal_gen/runtime/models/dits/stepvideo.py index 19286e644852..529c4995d2d8 100644 --- a/python/sglang/multimodal_gen/runtime/models/dits/stepvideo.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/stepvideo.py @@ -252,7 +252,7 @@ def forward( q = self._apply_rope(q, cos, sin) k = self._apply_rope(k, cos, sin) - output, _ = self.attn(q, k, v) # [B,heads,S,D] + output = self.attn(q, k, v) # [B,heads,S,D] output = rearrange(output, "b s h d -> b s (h d)") output, _ = self.wo(output) diff --git a/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py b/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py index 54f996499c6f..cb674e49195b 100644 --- a/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py @@ -13,7 +13,6 @@ from sglang.multimodal_gen.configs.sample.wan import WanTeaCacheParams from sglang.multimodal_gen.runtime.distributed.parallel_state import get_sp_world_size from sglang.multimodal_gen.runtime.layers.attention import ( - LocalAttention, UlyssesAttention_VSA, USPAttention, ) @@ -138,7 +137,7 @@ def __init__( self.norm_k = RMSNorm(dim, eps=eps) if qk_norm else nn.Identity() # Scaled dot product attention - self.attn = LocalAttention( + self.attn = USPAttention( num_heads=num_heads, head_size=self.head_dim, dropout_rate=0, @@ -391,7 +390,7 @@ def forward( query, key = _apply_rotary_emb( query, cos, sin, is_neox_style=False ), _apply_rotary_emb(key, cos, sin, is_neox_style=False) - attn_output, _ = self.attn1(query, key, value) + attn_output = self.attn1(query, key, value) attn_output = attn_output.flatten(2) attn_output, _ = self.to_out(attn_output) attn_output = attn_output.squeeze(1) @@ -560,7 +559,7 @@ def forward( query, cos, sin, is_neox_style=False ), _apply_rotary_emb(key, cos, sin, is_neox_style=False) - attn_output, _ = self.attn1(query, key, value, gate_compress=gate_compress) + attn_output = self.attn1(query, key, value, gate_compress=gate_compress) attn_output = attn_output.flatten(2) attn_output, _ = self.to_out(attn_output) attn_output = attn_output.squeeze(1) diff --git a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py index ea315ed14206..6743a72b0247 100644 --- a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py +++ b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py @@ -335,7 +335,7 @@ def _prepare_denoising_loop(self, batch: Req, server_args: ServerArgs): ) # Handle sequence parallelism AFTER TI2V processing - self._preprocess_sp_latents(batch) + self._preprocess_sp_latents(batch, server_args) latents = batch.latents # Shard z and reserved_frames_mask for TI2V if SP is enabled @@ -524,38 +524,29 @@ def _post_denoising_loop( torch.mps.current_allocated_memory(), ) - def _preprocess_sp_latents(self, batch: Req): + def _preprocess_sp_latents(self, batch: Req, server_args: ServerArgs): """Shard latents for Sequence Parallelism if applicable.""" - sp_world_size, rank_in_sp_group = get_sp_world_size(), get_sp_parallel_rank() if get_sp_world_size() <= 1: - batch.did_sp_shard_latents = False return - def _shard_tensor( - tensor: torch.Tensor | None, - ) -> tuple[torch.Tensor | None, bool]: - if tensor is None: - return None, False - - if tensor.dim() == 5: - time_dim = tensor.shape[2] - if time_dim > 0 and time_dim % sp_world_size == 0: - sharded_tensor = rearrange( - tensor, "b c (n t) h w -> b c n t h w", n=sp_world_size - ).contiguous() - sharded_tensor = sharded_tensor[:, :, rank_in_sp_group, :, :, :] - return sharded_tensor, True - - # For 4D image tensors or unsharded 5D tensors, return as is. - return tensor, False - - batch.latents, did_shard = _shard_tensor(batch.latents) - batch.did_sp_shard_latents = did_shard + if batch.latents is not None: + ( + batch.latents, + did_shard, + ) = server_args.pipeline_config.shard_latents_for_sp(batch, batch.latents) + batch.did_sp_shard_latents = did_shard + else: + batch.did_sp_shard_latents = False - # image_latent is sharded independently, but the decision to all-gather later - # is based on whether the main `latents` was sharded. - if batch.image_latent is not None: - batch.image_latent, _ = _shard_tensor(batch.image_latent) + # For I2I tasks like QwenImageEdit, the image_latent (input image) should be + # replicated on all SP ranks, not sharded, as it provides global context. + if ( + server_args.pipeline_config.task_type != ModelTaskType.I2I + and batch.image_latent is not None + ): + batch.image_latent, _ = server_args.pipeline_config.shard_latents_for_sp( + batch, batch.image_latent + ) def _postprocess_sp_latents( self, @@ -565,13 +556,20 @@ def _postprocess_sp_latents( ) -> tuple[torch.Tensor, torch.Tensor | None]: """Gather latents after Sequence Parallelism if they were sharded.""" if get_sp_world_size() > 1 and getattr(batch, "did_sp_shard_latents", False): - latents = sequence_model_parallel_all_gather(latents, dim=2) + latents = self.server_args.pipeline_config.gather_latents_for_sp(latents) if trajectory_tensor is not None: - # trajectory_tensor shape: [b, num_steps, c, t_local, h, w] -> gather on dim 3 + # trajectory_tensor shapes: + # - video: [b, num_steps, c, t_local, h, w] -> gather on dim=3 + # - image: [b, num_steps, s_local, d] -> gather on dim=2 trajectory_tensor = trajectory_tensor.to(get_local_torch_device()) + gather_dim = 3 if trajectory_tensor.dim() >= 5 else 2 trajectory_tensor = sequence_model_parallel_all_gather( - trajectory_tensor, dim=3 + trajectory_tensor, dim=gather_dim ) + if gather_dim == 2 and hasattr(batch, "raw_latent_shape"): + orig_s = batch.raw_latent_shape[1] + if trajectory_tensor.shape[2] > orig_s: + trajectory_tensor = trajectory_tensor[:, :, :orig_s, :] return latents, trajectory_tensor def start_profile(self, batch: Req): diff --git a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/image_encoding.py b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/image_encoding.py index 53ec3d1c2248..40112e68ca3f 100644 --- a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/image_encoding.py +++ b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/image_encoding.py @@ -104,7 +104,7 @@ def forward( image = batch.pil_image - # preprocess the imag_processor + # preprocess via vae_image_processor prompt_image = server_args.pipeline_config.preprocess_image( image, self.vae_image_processor ) diff --git a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/latent_preparation.py b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/latent_preparation.py index 34d7d82d45dc..4252da83b42d 100644 --- a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/latent_preparation.py +++ b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/latent_preparation.py @@ -87,7 +87,7 @@ def forward( latents = randn_tensor( shape, generator=generator, device=device, dtype=dtype ) - latents = server_args.pipeline_config.pack_latents( + latents = server_args.pipeline_config.maybe_pack_latents( latents, batch_size, batch ) else: diff --git a/python/sglang/multimodal_gen/runtime/server_args.py b/python/sglang/multimodal_gen/runtime/server_args.py index 6b05fb7a5cb8..e23dba802cc2 100644 --- a/python/sglang/multimodal_gen/runtime/server_args.py +++ b/python/sglang/multimodal_gen/runtime/server_args.py @@ -793,18 +793,6 @@ def get_provided_args( return provided_args def check_server_sp_args(self): - - if self.pipeline_config.task_type.is_image_gen(): - if ( - (self.sp_degree and self.sp_degree > 1) - or (self.ulysses_degree and self.ulysses_degree > 1) - or (self.ring_degree and self.ring_degree > 1) - ): - raise ValueError( - "SP is not supported for image generation models for now" - ) - self.sp_degree = self.ulysses_degree = self.ring_degree = 1 - if self.sp_degree == -1: # assume we leave all remaining gpus to sp num_gpus_per_group = self.dp_size * self.tp_size @@ -861,8 +849,11 @@ def check_server_sp_args(self): def check_server_dp_args(self): assert self.num_gpus % self.dp_size == 0, f"{self.num_gpus=}, {self.dp_size=}" assert self.dp_size >= 1, "--dp-size must be natural number" - self.dp_degree = self.num_gpus // self.dp_size + # NOTE: disable temporarily + # self.dp_degree = self.num_gpus // self.dp_size logger.info(f"Setting dp_degree to: {self.dp_degree}") + if self.dp_size > 1: + raise ValueError("DP is not yet supported") def check_server_args(self) -> None: """Validate inference arguments for consistency""" @@ -920,18 +911,6 @@ def check_server_args(self) -> None: self.pipeline_config.check_pipeline_config() - # Add preprocessing config validation if needed - if self.mode == ExecutionMode.PREPROCESS: - if self.preprocess_config is None: - raise ValueError( - "preprocess_config is not set in ServerArgs when mode is PREPROCESS" - ) - if self.preprocess_config.model_path == "": - self.preprocess_config.model_path = self.model_path - if not self.pipeline_config.vae_config.load_encoder: - self.pipeline_config.vae_config.load_encoder = True - self.preprocess_config.check_preprocess_config() - # parallelism self.check_server_dp_args() # allocate all remaining gpus for sp-size diff --git a/python/sglang/multimodal_gen/test/cli/test_generate_t2i_perf.py b/python/sglang/multimodal_gen/test/cli/test_generate_t2i_perf.py index 800828ccf761..1f34920fc2e7 100644 --- a/python/sglang/multimodal_gen/test/cli/test_generate_t2i_perf.py +++ b/python/sglang/multimodal_gen/test/cli/test_generate_t2i_perf.py @@ -15,25 +15,40 @@ class TestFlux_T2V(TestGenerateBase): extra_args = [] data_type: DataType = DataType.IMAGE thresholds = { - "test_single_gpu": 6.90 * 1.05, + "test_single_gpu": 6.5 * 1.05, + "test_usp": 8.3 * 1.05, } + def test_cfg_parallel(self): + pass + + def test_mixed(self): + pass + class TestQwenImage(TestGenerateBase): model_path = "Qwen/Qwen-Image" extra_args = [] data_type: DataType = DataType.IMAGE thresholds = { - "test_single_gpu": 11.7 * 1.05, + "test_single_gpu": 10.4 * 1.05, + "test_usp": 20.2 * 1.05, } + def test_cfg_parallel(self): + pass + + def test_mixed(self): + pass + class TestQwenImageEdit(TestGenerateBase): model_path = "Qwen/Qwen-Image-Edit" extra_args = [] data_type: DataType = DataType.IMAGE thresholds = { - "test_single_gpu": 43.5 * 1.05, + "test_single_gpu": 33.4 * 1.05, + "test_usp": 26.9 * 1.05, } prompt: str | None = ( @@ -57,13 +72,11 @@ def setUp(self): f"--output-path={self.output_path}", ] + [f"--image-path={img_path}"] - def test_single_gpu(self): - self._run_test( - name=f"{self.model_name()}_single_gpu", - args=None, - model_path=self.model_path, - test_key="test_single_gpu", - ) + def test_cfg_parallel(self): + pass + + def test_mixed(self): + pass if __name__ == "__main__": diff --git a/python/sglang/multimodal_gen/test/server/perf_baselines.json b/python/sglang/multimodal_gen/test/server/perf_baselines.json index 7fd88fd678b8..0d30b1b3059b 100644 --- a/python/sglang/multimodal_gen/test/server/perf_baselines.json +++ b/python/sglang/multimodal_gen/test/server/perf_baselines.json @@ -9,7 +9,7 @@ "denoise_stage": 0.05, "non_denoise_stage": 0.4, "denoise_step": 0.2, - "denoise_agg": 0.08 + "denoise_agg": 0.1 }, "improvement_reporting": { "threshold": 0.2 @@ -96,6 +96,72 @@ "49": 410.42 } }, + "qwen_image_t2i_2_gpus": { + "stages_ms": { + "InputValidationStage": 0.04, + "TextEncodingStage": 693.2, + "ConditioningStage": 0.02, + "TimestepPreparationStage": 2.84, + "LatentPreparationStage": 9.13, + "DenoisingStage": 24529.77, + "DecodingStage": 612.79 + }, + "denoise_step_ms": { + "0": 405.94, + "1": 420.06, + "2": 414.79, + "3": 392.4, + "4": 408.14, + "5": 605.0, + "6": 469.39, + "7": 574.04, + "8": 539.61, + "9": 452.93, + "10": 279.36, + "11": 271.8, + "12": 438.26, + "13": 552.65, + "14": 576.1, + "15": 679.84, + "16": 543.0, + "17": 512.81, + "18": 522.27, + "19": 545.06, + "20": 545.85, + "21": 523.83, + "22": 519.36, + "23": 513.78, + "24": 532.54, + "25": 524.94, + "26": 542.59, + "27": 570.91, + "28": 568.73, + "29": 564.52, + "30": 564.57, + "31": 544.94, + "32": 496.81, + "33": 488.98, + "34": 457.18, + "35": 441.42, + "36": 437.44, + "37": 477.6, + "38": 429.17, + "39": 465.55, + "40": 448.25, + "41": 511.83, + "42": 450.6, + "43": 375.78, + "44": 504.4, + "45": 524.44, + "46": 535.22, + "47": 514.52, + "48": 431.58, + "49": 410.68 + }, + "expected_e2e_ms": 25850.45, + "expected_avg_denoise_ms": 490.43, + "expected_median_denoise_ms": 512.32 + }, "flux_image_t2i": { "stages_ms": { "InputValidationStage": 0.03, @@ -162,6 +228,72 @@ "expected_avg_denoise_ms": 165.83, "expected_median_denoise_ms": 169.33 }, + "flux_image_t2i_2_gpus": { + "stages_ms": { + "InputValidationStage": 0.03, + "TextEncodingStage": 74.47, + "ConditioningStage": 0.01, + "TimestepPreparationStage": 2.23, + "LatentPreparationStage": 6.17, + "DenoisingStage": 8400.49, + "DecodingStage": 381.56 + }, + "denoise_step_ms": { + "0": 166.27, + "1": 59.6, + "2": 167.31, + "3": 168.7, + "4": 168.83, + "5": 171.05, + "6": 174.64, + "7": 170.92, + "8": 169.69, + "9": 169.21, + "10": 167.71, + "11": 177.62, + "12": 166.44, + "13": 174.61, + "14": 170.43, + "15": 169.47, + "16": 167.24, + "17": 169.15, + "18": 169.51, + "19": 172.3, + "20": 172.19, + "21": 172.36, + "22": 168.39, + "23": 168.47, + "24": 170.55, + "25": 170.96, + "26": 168.43, + "27": 169.01, + "28": 169.62, + "29": 170.95, + "30": 171.83, + "31": 171.92, + "32": 170.1, + "33": 170.46, + "34": 169.91, + "35": 168.91, + "36": 170.27, + "37": 170.23, + "38": 169.62, + "39": 169.66, + "40": 169.57, + "41": 169.42, + "42": 168.59, + "43": 171.12, + "44": 169.6, + "45": 169.93, + "46": 171.23, + "47": 171.03, + "48": 170.14, + "49": 169.4 + }, + "expected_e2e_ms": 9006.3, + "expected_avg_denoise_ms": 167.89, + "expected_median_denoise_ms": 169.67 + }, "qwen_image_edit_ti2i": { "notes": "single uploaded reference image, Qwen/Qwen-Image-Edit", "expected_e2e_ms": 138500.0, @@ -465,197 +597,195 @@ }, "wan2_1_i2v_14b_480P_2gpu": { "stages_ms": { - "InputValidationStage": 32.94, - "TextEncodingStage": 2316.5, - "ImageEncodingStage": 3026.2, + "InputValidationStage": 33.57, + "TextEncodingStage": 2424.73, + "ImageEncodingStage": 3462.55, "ConditioningStage": 0.01, "TimestepPreparationStage": 2.69, "LatentPreparationStage": 9.73, "ImageVAEEncodingStage": 2290.98, - "DenoisingStage": 385080.09, - "DecodingStage": 2984.69, - "per_frame_generation": null + "DenoisingStage": 414428.85, + "DecodingStage": 3016.1 }, "denoise_step_ms": { - "0": 8785.36, - "1": 7644.16, - "2": 7687.27, - "3": 7703.9, - "4": 7710.61, - "5": 7716.32, - "6": 7714.26, - "7": 7711.27, - "8": 7711.08, - "9": 7706.57, - "10": 7700.78, - "11": 7696.03, - "12": 7704.73, - "13": 7699.99, - "14": 7705.33, - "15": 7701.11, - "16": 7704.04, - "17": 7695.31, - "18": 7693.63, - "19": 7686.34, - "20": 7683.27, - "21": 7689.82, - "22": 7688.74, - "23": 7686.01, - "24": 7675.43, - "25": 7679.86, - "26": 7676.75, - "27": 7671.65, - "28": 7667.0, - "29": 7669.83, - "30": 7660.5, - "31": 7666.82, - "32": 7660.89, - "33": 7668.75, - "34": 7662.27, - "35": 7659.71, - "36": 7661.36, - "37": 7664.87, - "38": 7666.93, - "39": 7661.05, - "40": 7661.88, - "41": 7657.96, - "42": 7660.6, - "43": 7669.82, - "44": 7655.78, - "45": 7654.25, - "46": 7656.56, - "47": 7652.37, - "48": 7657.61, - "49": 7644.6 + "0": 9304.67, + "1": 8218.78, + "2": 8269.27, + "3": 8291.59, + "4": 8308.29, + "5": 8300.75, + "6": 8302.76, + "7": 8297.95, + "8": 8295.26, + "9": 8296.45, + "10": 8287.48, + "11": 8275.98, + "12": 8281.9, + "13": 8283.39, + "14": 8264.96, + "15": 8275.66, + "16": 8271.89, + "17": 8273.77, + "18": 8279.34, + "19": 8271.89, + "20": 8265.83, + "21": 8259.99, + "22": 8260.36, + "23": 8270.06, + "24": 8271.58, + "25": 8272.39, + "26": 8267.87, + "27": 8277.09, + "28": 8264.49, + "29": 8266.14, + "30": 8263.67, + "31": 8273.82, + "32": 8260.5, + "33": 8268.44, + "34": 8253.2, + "35": 8244.32, + "36": 8258.15, + "37": 8256.65, + "38": 8255.48, + "39": 8260.09, + "40": 8250.99, + "41": 8253.52, + "42": 8247.39, + "43": 8252.7, + "44": 8243.67, + "45": 8251.94, + "46": 8258.73, + "47": 8240.57, + "48": 8249.64, + "49": 8248.14 }, - "expected_e2e_ms": 395758.23, - "expected_avg_denoise_ms": 7701.42, - "expected_median_denoise_ms": 7676.09 + "expected_e2e_ms": 425569.98, + "expected_avg_denoise_ms": 8288.39, + "expected_median_denoise_ms": 8267.01 }, "wan2_1_i2v_14b_720P_2gpu": { "stages_ms": { "InputValidationStage": 53.67, "TextEncodingStage": 2838, "ImageEncodingStage": 3123.99, - "ConditioningStage": 0.02, + "ConditioningStage": 0.01, "TimestepPreparationStage": 3.39, - "LatentPreparationStage": 6.68, + "LatentPreparationStage": 8.41, "ImageVAEEncodingStage": 2261.05, - "DenoisingStage": 386761.14, - "DecodingStage": 2968.35, - "per_frame_generation": null + "DenoisingStage": 417418.12, + "DecodingStage": 2968.35 }, "denoise_step_ms": { - "0": 10021.98, - "1": 7633.62, - "2": 7676.46, - "3": 7704.68, - "4": 7725.09, - "5": 7732.86, - "6": 7735.42, - "7": 7739.05, - "8": 7740.89, - "9": 7724.35, - "10": 7730.2, - "11": 7713.23, - "12": 7715.93, - "13": 7710.93, - "14": 7699.95, - "15": 7704.72, - "16": 7704.03, - "17": 7700.47, - "18": 7702.0, - "19": 7705.92, - "20": 7704.35, - "21": 7705.11, - "22": 7693.85, - "23": 7696.91, - "24": 7689.6, - "25": 7681.2, - "26": 7675.63, - "27": 7678.95, - "28": 7683.82, - "29": 7681.07, - "30": 7671.07, - "31": 7674.65, - "32": 7679.56, - "33": 7674.59, - "34": 7672.16, - "35": 7679.68, - "36": 7670.81, - "37": 7661.84, - "38": 7668.58, - "39": 7667.1, - "40": 7670.22, - "41": 7664.97, - "42": 7667.3, - "43": 7668.87, - "44": 7663.43, - "45": 7656.34, - "46": 7662.81, - "47": 7662.05, - "48": 7654.13, - "49": 7648.62 + "0": 11848.08, + "1": 8220.3, + "2": 8274.3, + "3": 8298.9, + "4": 8303.34, + "5": 8322.44, + "6": 8314.37, + "7": 8318.54, + "8": 8304.94, + "9": 8303.04, + "10": 8305.22, + "11": 8296.22, + "12": 8289.2, + "13": 8294.19, + "14": 8294.87, + "15": 8285.96, + "16": 8284.98, + "17": 8281.61, + "18": 8277.35, + "19": 8287.46, + "20": 8280.3, + "21": 8279.18, + "22": 8279.37, + "23": 8280.16, + "24": 8282.67, + "25": 8272.14, + "26": 8279.37, + "27": 8271.66, + "28": 8274.6, + "29": 8272.88, + "30": 8273.76, + "31": 8266.17, + "32": 8267.77, + "33": 8266.88, + "34": 8263.14, + "35": 8265.97, + "36": 8267.76, + "37": 8268.03, + "38": 8262.24, + "39": 8261.4, + "40": 8263.65, + "41": 8272.46, + "42": 8254.9, + "43": 8261.03, + "44": 8252.92, + "45": 8262.49, + "46": 8253.67, + "47": 8254.92, + "48": 8257.08, + "49": 8236.56 }, - "expected_e2e_ms": 397541.45, - "expected_avg_denoise_ms": 7735.02, - "expected_median_denoise_ms": 7681.14 + "expected_e2e_ms": 427536.9, + "expected_avg_denoise_ms": 8348.21, + "expected_median_denoise_ms": 8274.45 }, "wan2_2_t2v_a14b_2gpu": { "stages_ms": { - "InputValidationStage": 0.09, - "TextEncodingStage": 2322.57, - "ConditioningStage": 0.03, - "TimestepPreparationStage": 2.29, - "LatentPreparationStage": 3.08, - "DenoisingStage": 79913.08, - "DecodingStage": 1339.58 + "InputValidationStage": 0.07, + "TextEncodingStage": 2507.83, + "ConditioningStage": 0.02, + "TimestepPreparationStage": 3.22, + "LatentPreparationStage": 2.99, + "DenoisingStage": 103136.69, + "DecodingStage": 1431.71 }, "denoise_step_ms": { - "0": 19269.37, - "1": 691.64, - "2": 699.28, - "3": 696.55, - "4": 698.6, - "5": 704.56, - "6": 699.26, - "7": 700.84, - "8": 700.27, - "9": 704.15, - "10": 699.04, - "11": 704.79, - "12": 701.48, - "13": 707.24, - "14": 697.54, - "15": 698.89, - "16": 697.97, - "17": 699.34, - "18": 697.68, - "19": 697.42, - "20": 697.14, - "21": 700.14, - "22": 696.75, - "23": 702.36, - "24": 697.3, - "25": 703.97, - "26": 33676.93, - "27": 700.4, - "28": 703.68, - "29": 691.86, - "30": 706.1, - "31": 704.18, - "32": 700.34, - "33": 698.62, - "34": 698.66, - "35": 699.77, - "36": 700.96, - "37": 701.02, - "38": 703.98, - "39": 702.18 + "0": 24471.86, + "1": 757.31, + "2": 760.07, + "3": 758.74, + "4": 762.4, + "5": 755.83, + "6": 760.06, + "7": 756.38, + "8": 755.38, + "9": 754.25, + "10": 754.51, + "11": 753.46, + "12": 753.67, + "13": 753.08, + "14": 754.83, + "15": 753.04, + "16": 754.28, + "17": 754.45, + "18": 758.19, + "19": 756.23, + "20": 755.14, + "21": 755.92, + "22": 759.52, + "23": 762.09, + "24": 756.8, + "25": 758.86, + "26": 48787.27, + "27": 758.5, + "28": 757.57, + "29": 757.16, + "30": 758.43, + "31": 763.31, + "32": 753.69, + "33": 754.91, + "34": 752.03, + "35": 763.65, + "36": 760.96, + "37": 754.31, + "38": 753.64, + "39": 756.95 }, - "expected_e2e_ms": 83595.94, - "expected_avg_denoise_ms": 1988.81, - "expected_median_denoise_ms": 700.2 + "expected_e2e_ms": 106895.63, + "expected_avg_denoise_ms": 2550.47, + "expected_median_denoise_ms": 756.59 }, "wan2_1_t2v_14b_2gpu": { "stages_ms": { diff --git a/python/sglang/multimodal_gen/test/server/test_server_utils.py b/python/sglang/multimodal_gen/test/server/test_server_utils.py index 8817ec0942f6..3c1b419cf7ed 100644 --- a/python/sglang/multimodal_gen/test/server/test_server_utils.py +++ b/python/sglang/multimodal_gen/test/server/test_server_utils.py @@ -132,7 +132,7 @@ def start(self) -> ServerContext: env["SGLANG_PERF_LOG_DIR"] = log_dir.as_posix() # TODO: unify with run_command - print(f"Running command: {shlex.join(command)}") + logger.info(f"Running command: {shlex.join(command)}") process = subprocess.Popen( command, diff --git a/python/sglang/multimodal_gen/test/server/testcase_configs.py b/python/sglang/multimodal_gen/test/server/testcase_configs.py index c781ec838841..83cf91ef8c78 100644 --- a/python/sglang/multimodal_gen/test/server/testcase_configs.py +++ b/python/sglang/multimodal_gen/test/server/testcase_configs.py @@ -369,6 +369,25 @@ def from_req_perf_record( custom_validator="video", num_gpus=2, ), + DiffusionTestCase( + id="qwen_image_t2i_2_gpus", + model_path="Qwen/Qwen-Image", + modality="image", + prompt="A futuristic cityscape at sunset with flying cars", + output_size="1024x1024", + warmup_text=1, + warmup_edit=0, + num_gpus=2, + ), + DiffusionTestCase( + id="flux_image_t2i_2_gpus", + model_path="black-forest-labs/FLUX.1-dev", + modality="image", + prompt="A futuristic cityscape at sunset with flying cars", + output_size="1024x1024", + warmup_text=1, + warmup_edit=0, + ), ] # Load global configuration diff --git a/python/sglang/multimodal_gen/test/test_utils.py b/python/sglang/multimodal_gen/test/test_utils.py index 33e60a91921d..9446680a19fb 100644 --- a/python/sglang/multimodal_gen/test/test_utils.py +++ b/python/sglang/multimodal_gen/test/test_utils.py @@ -385,8 +385,6 @@ def test_single_gpu(self): def test_cfg_parallel(self): """cfg parallel""" - if self.data_type == DataType.IMAGE: - return self._run_test( name=f"{self.model_name()}_cfg_parallel", args="--num-gpus 2 --enable-cfg-parallel", @@ -396,8 +394,6 @@ def test_cfg_parallel(self): def test_usp(self): """usp""" - if self.data_type == DataType.IMAGE: - return self._run_test( name=f"{self.model_name()}_usp", args="--num-gpus 4 --ulysses-degree=2 --ring-degree=2", @@ -407,8 +403,6 @@ def test_usp(self): def test_mixed(self): """mixed""" - if self.data_type == DataType.IMAGE: - return self._run_test( name=f"{self.model_name()}_mixed", args="--num-gpus 4 --ulysses-degree=2 --ring-degree=1 --enable-cfg-parallel",