diff --git a/python/sglang/multimodal_gen/configs/pipeline_configs/base.py b/python/sglang/multimodal_gen/configs/pipeline_configs/base.py index 5d3437247f3f..c363d00b7419 100644 --- a/python/sglang/multimodal_gen/configs/pipeline_configs/base.py +++ b/python/sglang/multimodal_gen/configs/pipeline_configs/base.py @@ -102,6 +102,14 @@ def shard_rotary_emb_for_sp(emb): return emb +def maybe_unpad_latents(latents, batch): + # If SP padding was applied, remove extra tokens before reshaping + target_tokens = batch.raw_latent_shape[-1] * batch.raw_latent_shape[-2] + if latents.shape[1] > target_tokens: + latents = latents[:, :target_tokens, :] + return latents + + # config for a single pipeline @dataclass class PipelineConfig: @@ -310,6 +318,7 @@ def get_neg_prompt_embeds(self, batch): return batch.negative_prompt_embeds def post_denoising_loop(self, latents, batch): + latents = maybe_unpad_latents(latents, batch) return latents def prepare_pos_cond_kwargs(self, batch, device, rotary_emb, dtype): @@ -655,10 +664,7 @@ def _unpad_and_unpack_latents(self, latents, batch): height = 2 * (int(batch.height) // (vae_scale_factor * 2)) width = 2 * (int(batch.width) // (vae_scale_factor * 2)) - # If SP padding was applied, remove extra tokens before reshaping - target_tokens = (height // 2) * (width // 2) - if latents.shape[1] > target_tokens: - latents = latents[:, :target_tokens, :] + latents = maybe_unpad_latents(latents, batch) latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) latents = latents.permute(0, 3, 1, 4, 2, 5) diff --git a/python/sglang/multimodal_gen/configs/sample/sampling_params.py b/python/sglang/multimodal_gen/configs/sample/sampling_params.py index 07ab8950d59f..3beb3c0d1af8 100644 --- a/python/sglang/multimodal_gen/configs/sample/sampling_params.py +++ b/python/sglang/multimodal_gen/configs/sample/sampling_params.py @@ -228,7 +228,7 @@ def _adjust( logger.debug(f"Setting num_frames to 1 because this is an image-gen model") 
self.num_frames = 1 self.data_type = DataType.IMAGE - else: + elif self.adjust_frames: # NOTE: We must apply adjust_num_frames BEFORE the SP alignment logic below. # If we apply it after, adjust_num_frames might modify the frame count # and break the divisibility constraint (alignment) required by num_gpus. @@ -536,8 +536,8 @@ def add_cli_args(parser: Any) -> Any: default=SamplingParams.adjust_frames, help=( "Enable/disable adjusting num_frames to evenly split latent frames across GPUs " - "and satisfy model temporal constraints. Default: true. " - "Examples: --adjust-frames, --adjust-frames true, --adjust-frames false." + "and satisfy model temporal constraints. If disabled, tokens might be padded for SP. " + "Default: true. Examples: --adjust-frames, --adjust-frames true, --adjust-frames false." ), ) return parser diff --git a/python/sglang/multimodal_gen/runtime/layers/usp.py b/python/sglang/multimodal_gen/runtime/layers/usp.py index 4f3804c91af1..9fa8403a3c36 100644 --- a/python/sglang/multimodal_gen/runtime/layers/usp.py +++ b/python/sglang/multimodal_gen/runtime/layers/usp.py @@ -51,10 +51,10 @@ def _usp_input_all_to_all(x: torch.Tensor, head_dim: int = 1) -> torch.Tensor: Perform Ulysses-style input all-to-all over the head dimension. Default layout expects heads at dim=1 and sequence at dim=2: - [b, h, s_local, d] -> [b, h // world_size, s_global, d] + [b, h, s_local, d] -> [b, h_local, s_global, d] If heads are at dim=2 (input is [b, s_local, h, d]), set head_dim=2, and the - function returns [b, s_global, h // world_size, d], preserving the original + function returns [b, s_global, h_local, d], preserving the original head/sequence dim ordering. 
Args: @@ -83,11 +83,11 @@ def _usp_input_all_to_all(x: torch.Tensor, head_dim: int = 1) -> torch.Tensor: h % world_size == 0 ), f"h ({h}) must be divisible by world_size ({world_size})" - # [b, h, s, d] -> [h, b, s, d] + # [b, h, s_local, d] -> [h, b, s_local, d] x_c = x_c.permute(1, 0, 2, 3).contiguous() # all-to-all along h x_c = _usp_all_to_all_single(x_c) - # -> [b, h // world, s * world, d] + # -> [b, h_local, s * world_size, d] x_c = ( x_c.reshape(world_size, h // world_size, b, -1, d) .permute(2, 1, 0, 3, 4) @@ -109,7 +109,7 @@ def _usp_output_all_to_all(x: torch.Tensor, head_dim: int = 1) -> torch.Tensor: Perform Ulysses-style output all-to-all over the head dimension (inverse of input). Default layout expects heads at dim=1 and sequence at dim=2: - [b, h // world_size, s_global, d] -> [b, h, s_local, d] + [b, h_local, s_global, d] -> [b, h, s_local, d] If heads are at dim=2 (input is [b, s_global, h // world_size, d]), set head_dim=2, and the function returns [b, s_local, h, d], preserving the original head/sequence @@ -141,10 +141,10 @@ def _usp_output_all_to_all(x: torch.Tensor, head_dim: int = 1) -> torch.Tensor: s % world_size == 0 ), f"s ({s}) must be divisible by world_size ({world_size})" - # [b, h, s, d] -> [s, b, h, d] + # [b, h_local, s, d] -> [s, b, h_local, d] x_c = x_c.permute(2, 0, 1, 3).contiguous() x_c = _usp_all_to_all_single(x_c) - # -> [b, h * world, s // world, d] + # -> [b, h, s_local, d] x_c = ( x_c.reshape(world_size, s // world_size, b, -1, d) .permute(2, 0, 3, 1, 4)