diff --git a/python/sglang/multimodal_gen/configs/models/vaes/qwenimage.py b/python/sglang/multimodal_gen/configs/models/vaes/qwenimage.py index af9fa9d2a0d8..fae449fe75e7 100644 --- a/python/sglang/multimodal_gen/configs/models/vaes/qwenimage.py +++ b/python/sglang/multimodal_gen/configs/models/vaes/qwenimage.py @@ -38,6 +38,8 @@ class QwenImageVAEConfig(VAEConfig): use_temporal_tiling: bool = False use_parallel_tiling: bool = False + use_parallel_decode: bool = False + def get_vae_scale_factor(self): return 2 ** len(self.arch_config.temperal_downsample) diff --git a/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py b/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py index ebf2daa5acb5..20b193443ad5 100644 --- a/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py @@ -47,7 +47,10 @@ apply_flashinfer_rope_qk_inplace, ) from sglang.multimodal_gen.runtime.models.dits.base import CachableDiT -from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum +from sglang.multimodal_gen.runtime.platforms import ( + AttentionBackendEnum, + current_platform, +) from sglang.multimodal_gen.runtime.utils.layerwise_offload import OffloadableDiTMixin from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger @@ -826,30 +829,52 @@ def _modulate( shift, scale, gate = mod_params.chunk(3, dim=-1) if index is not None: - actual_batch = x.shape[0] - shift0, shift1 = ( - shift[:actual_batch], - shift[actual_batch : 2 * actual_batch], - ) - scale0, scale1 = ( - scale[:actual_batch], - scale[actual_batch : 2 * actual_batch], - ) - gate0, gate1 = gate[:actual_batch], gate[actual_batch : 2 * actual_batch] - if not x.is_contiguous(): - x = x.contiguous() - if not index.is_contiguous(): - index = index.contiguous() - if is_scale_residual: - if not residual_x.is_contiguous(): - residual_x = residual_x.contiguous() - if not gate_x.is_contiguous(): - gate_x = gate_x.contiguous() - x, residual_out, gate_result = ( - fuse_residual_layernorm_scale_shift_gate_select01_kernel( + # ROCm currently fails to compile the select01 Triton kernel, so + # keep using the torch.where fallback there. + if x.is_cuda and not current_platform.is_hip(): + actual_batch = x.shape[0] + shift0, shift1 = ( + shift[:actual_batch], + shift[actual_batch : 2 * actual_batch], + ) + scale0, scale1 = ( + scale[:actual_batch], + scale[actual_batch : 2 * actual_batch], + ) + gate0, gate1 = ( + gate[:actual_batch], + gate[actual_batch : 2 * actual_batch], + ) + if not x.is_contiguous(): + x = x.contiguous() + if not index.is_contiguous(): + index = index.contiguous() + if is_scale_residual: + if not residual_x.is_contiguous(): + residual_x = residual_x.contiguous() + if not gate_x.is_contiguous(): + gate_x = gate_x.contiguous() + x, residual_out, gate_result = ( + fuse_residual_layernorm_scale_shift_gate_select01_kernel( + x, + residual=residual_x, + residual_gate=gate_x, + weight=getattr(norm_module.norm, "weight", None), + bias=getattr(norm_module.norm, "bias", None), + scale0=scale0.contiguous(), + shift0=shift0.contiguous(), + gate0=gate0.contiguous(), + scale1=scale1.contiguous(), + shift1=shift1.contiguous(), + gate1=gate1.contiguous(), + index=index, + eps=norm_module.eps, + ) + ) + return x, residual_out, gate_result + else: + x, gate_result = fuse_layernorm_scale_shift_gate_select01_kernel( x, - residual=residual_x, - residual_gate=gate_x, weight=getattr(norm_module.norm, "weight", None), bias=getattr(norm_module.norm, "bias", None), scale0=scale0.contiguous(), @@ -861,39 +886,45 @@ def _modulate( index=index, eps=norm_module.eps, ) - ) - return x, residual_out, gate_result + return x, gate_result else: - x, gate_result = fuse_layernorm_scale_shift_gate_select01_kernel( - x, - weight=getattr(norm_module.norm, "weight", None), - bias=getattr(norm_module.norm, "bias", None), - scale0=scale0.contiguous(), - shift0=shift0.contiguous(), - gate0=gate0.contiguous(), - scale1=scale1.contiguous(), - shift1=shift1.contiguous(), - gate1=gate1.contiguous(), - index=index, - eps=norm_module.eps, + actual_batch = x.shape[0] + shift0, shift1 = ( + shift[:actual_batch], + shift[actual_batch : 2 * actual_batch], ) - return x, gate_result + scale0, scale1 = ( + scale[:actual_batch], + scale[actual_batch : 2 * actual_batch], + ) + gate0, gate1 = ( + gate[:actual_batch], + gate[actual_batch : 2 * actual_batch], + ) + index = index.to(dtype=torch.bool).unsqueeze(-1) + shift_result = torch.where( + index, shift1.unsqueeze(1), shift0.unsqueeze(1) + ) + scale_result = torch.where( + index, scale1.unsqueeze(1), scale0.unsqueeze(1) + ) + gate_result = torch.where(index, gate1.unsqueeze(1), gate0.unsqueeze(1)) else: shift_result = shift.unsqueeze(1) scale_result = scale.unsqueeze(1) gate_result = gate.unsqueeze(1) - if is_scale_residual: - modulated, residual_out = norm_module( - residual=residual_x, - x=x, - gate=gate_x, - shift=shift_result, - scale=scale_result, - ) - return modulated, residual_out, gate_result - else: - modulated = norm_module(x=x, shift=shift_result, scale=scale_result) - return modulated, gate_result + if is_scale_residual: + modulated, residual_out = norm_module( + residual=residual_x, + x=x, + gate=gate_x, + shift=shift_result, + scale=scale_result, + ) + return modulated, residual_out, gate_result + else: + modulated = norm_module(x=x, shift=shift_result, scale=scale_result) + return modulated, gate_result def forward( self, @@ -1127,8 +1158,8 @@ def build_modulate_index(self, img_shapes: tuple[int, int, int], device): first_size = sample[0][0] * sample[0][1] * sample[0][2] total_size = sum(s[0] * s[1] * s[2] for s in sample) if sp_world_size > 1: - first_local_size = _local_seq_len(first_size) - tail_local_size = _local_seq_len(total_size - first_size) + first_local_size = _local_seq_len(first_size, sp_world_size) + tail_local_size = _local_seq_len(total_size - first_size, sp_world_size) idx = torch.cat( [ torch.zeros(first_local_size, device=device, dtype=torch.int), diff --git a/python/sglang/multimodal_gen/runtime/models/vaes/autoencoder_kl_qwenimage.py b/python/sglang/multimodal_gen/runtime/models/vaes/autoencoder_kl_qwenimage.py index 42c5426d7434..3178783c4362 100644 --- a/python/sglang/multimodal_gen/runtime/models/vaes/autoencoder_kl_qwenimage.py +++ b/python/sglang/multimodal_gen/runtime/models/vaes/autoencoder_kl_qwenimage.py @@ -3,6 +3,7 @@ from typing import Optional, Tuple, Union import torch +import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F from diffusers.models.activations import get_activation @@ -13,7 +14,10 @@ from diffusers.models.modeling_outputs import AutoencoderKLOutput from sglang.multimodal_gen.configs.models.vaes.qwenimage import QwenImageVAEConfig -from sglang.multimodal_gen.runtime.distributed import get_local_torch_device +from sglang.multimodal_gen.runtime.distributed import ( + get_local_torch_device, + get_sp_world_size, +) from sglang.multimodal_gen.runtime.models.vaes.common import ParallelTiledVAE from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger @@ -789,6 +793,7 @@ def __init__( self.input_channels = config.arch_config.input_channels self.latents_mean = config.arch_config.latents_mean self.config = config.arch_config + self.use_parallel_decode = config.use_parallel_decode self.encoder = QwenImageEncoder3d( base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout, @@ -841,6 +846,8 @@ def __init__( .to(cuda_device, dtype) ) + + def enable_tiling( self, tile_sample_min_height: Optional[int] = None, @@ -956,30 +963,43 @@ def encode( return posterior - def _decode(self, z: torch.Tensor, return_dict: bool = True): + def _decode_with_parallel_dispatch(self, z: torch.Tensor) -> DecoderOutput: + if self.use_parallel_decode and get_sp_world_size() > 1: + num_frame = z.shape[2] + num_sample_frames = (num_frame - 1) * self.temporal_compression_ratio + 1 + decoded = super().parallel_tiled_decode(z)[:, :, :num_sample_frames] + return DecoderOutput(sample=decoded) + + return DecoderOutput(sample=self._decode(z)) + + def _decode(self, z: torch.Tensor) -> torch.Tensor: _, _, num_frame, height, width = z.shape tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height): - return self.tiled_decode(z, return_dict=return_dict) + return self.tiled_decode(z).sample self.clear_cache() x = self.post_quant_conv(z) for i in range(num_frame): self._conv_idx = [0] if i == 0: - out = self.decoder(x[:, :, i: i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) + out = self.decoder( + x[:, :, i : i + 1, :, :], + feat_cache=self._feat_map, + feat_idx=self._conv_idx, + ) else: - out_ = self.decoder(x[:, :, i: i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) + out_ = self.decoder( + x[:, :, i : i + 1, :, :], + feat_cache=self._feat_map, + feat_idx=self._conv_idx, + ) out = torch.cat([out, out_], 2) - out = torch.clamp(out, min=-1.0, max=1.0) self.clear_cache() - if not return_dict: - return (out,) - - return DecoderOutput(sample=out) + return out def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: r""" @@ -996,29 +1016,121 @@ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutp returned. """ if self.use_slicing and z.shape[0] > 1: - decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + decoded_slices = [ + self._decode_with_parallel_dispatch(z_slice).sample + for z_slice in z.split(1) + ] decoded = torch.cat(decoded_slices) else: - decoded = self._decode(z).sample + decoded = self._decode_with_parallel_dispatch(z).sample return decoded - def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + def blend_v( + self, a: torch.Tensor, b: torch.Tensor, blend_extent: int + ) -> torch.Tensor: blend_extent = min(a.shape[-2], b.shape[-2], blend_extent) - for y in range(blend_extent): - b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * ( - y / blend_extent - ) + if blend_extent <= 0: + return b + weight = ( + torch.arange(blend_extent, device=b.device, dtype=b.dtype) / blend_extent + ).view(1, 1, 1, blend_extent, 1) + b[:, :, :, :blend_extent, :] = ( + a[:, :, :, -blend_extent:, :] * (1 - weight) + + b[:, :, :, :blend_extent, :] * weight + ) return b - def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + def blend_h( + self, a: torch.Tensor, b: torch.Tensor, blend_extent: int + ) -> torch.Tensor: blend_extent = min(a.shape[-1], b.shape[-1], blend_extent) - for x in range(blend_extent): - b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * ( - x / blend_extent - ) + if blend_extent <= 0: + return b + weight = ( + torch.arange(blend_extent, device=b.device, dtype=b.dtype) / blend_extent + ).view(1, 1, 1, 1, blend_extent) + b[:, :, :, :, :blend_extent] = ( + a[:, :, :, :, -blend_extent:] * (1 - weight) + + b[:, :, :, :, :blend_extent] * weight + ) return b + def _process_parallel_tiled_outputs( + self, + results: torch.Tensor, + local_dim_metadata: list[torch.Size], + z: torch.Tensor, + world_size: int, + rank: int, + num_t_tiles: int, + num_h_tiles: int, + num_w_tiles: int, + total_spatial_tiles: int, + blend_height: int, + blend_width: int, + ) -> torch.Tensor: + local_size = torch.tensor( + [results.size(0)], device=results.device, dtype=torch.int64 + ) + if rank == 0: + gathered_sizes = [ + torch.zeros(1, device=results.device, dtype=torch.int64) + for _ in range(world_size) + ] + else: + gathered_sizes = None + dist.gather(local_size, gather_list=gathered_sizes, dst=0) + + max_size = 0 + if rank == 0: + max_size = max(size.item() for size in gathered_sizes) + + max_size_tensor = torch.tensor( + [max_size], device=results.device, dtype=torch.int64 + ) + dist.broadcast(max_size_tensor, src=0) + max_size = int(max_size_tensor.item()) + + padded_results = torch.zeros( + max_size, device=results.device, dtype=results.dtype + ) + padded_results[: results.size(0)] = results + + gathered_dim_metadata = [None] * world_size + dist.all_gather_object(gathered_dim_metadata, local_dim_metadata) + + if rank == 0: + gathered_results = [ + torch.empty_like(padded_results) for _ in range(world_size) + ] + else: + gathered_results = None + dist.gather(padded_results, gather_list=gathered_results, dst=0) + + if rank == 0: + gathered_results = torch.stack(gathered_results, dim=0).contiguous() + dec = super()._merge_parallel_tiled_results( + gathered_results, + gathered_dim_metadata, + num_t_tiles, + num_h_tiles, + num_w_tiles, + total_spatial_tiles, + blend_height, + blend_width, + ) + shape_tensor = torch.tensor(dec.shape, device=dec.device, dtype=torch.int64) + else: + dec = None + shape_tensor = torch.zeros(5, device=z.device, dtype=torch.int64) + + dist.broadcast(shape_tensor, src=0) + if rank != 0: + dec = z.new_empty(tuple(shape_tensor.tolist())) + dist.broadcast(dec, src=0) + return dec + def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput: r"""Encode a batch of images using a tiled encoder. diff --git a/python/sglang/multimodal_gen/runtime/models/vaes/common.py b/python/sglang/multimodal_gen/runtime/models/vaes/common.py index 095ce49574f5..7a55c8c4a8a4 100644 --- a/python/sglang/multimodal_gen/runtime/models/vaes/common.py +++ b/python/sglang/multimodal_gen/runtime/models/vaes/common.py @@ -220,14 +220,110 @@ def _parallel_data_generator( _start_shape += mul_shape global_idx += 1 + def _merge_parallel_tiled_results( + self, + gathered_results: torch.Tensor, + gathered_dim_metadata: list[list[torch.Size]], + num_t_tiles: int, + num_h_tiles: int, + num_w_tiles: int, + total_spatial_tiles: int, + blend_height: int, + blend_width: int, + ) -> torch.Tensor: + data: list = [ + [[[] for _ in range(num_w_tiles)] for _ in range(num_h_tiles)] + for _ in range(num_t_tiles) + ] + for current_data, global_idx in self._parallel_data_generator( + gathered_results, gathered_dim_metadata + ): + t_idx = global_idx // total_spatial_tiles + spatial_idx = global_idx % total_spatial_tiles + h_idx = spatial_idx // num_w_tiles + w_idx = spatial_idx % num_w_tiles + data[t_idx][h_idx][w_idx] = current_data + + result_slices = [] + last_slice_data = None + for i, tem_data in enumerate(data): + slice_data = self._merge_spatial_tiles( + tem_data, + blend_height, + blend_width, + self.tile_sample_stride_height, + self.tile_sample_stride_width, + ) + if i > 0: + slice_data = self.blend_t( + last_slice_data, slice_data, self.blend_num_frames + ) + result_slices.append( + slice_data[:, :, : self.tile_sample_stride_num_frames, :, :] + ) + else: + result_slices.append( + slice_data[:, :, : self.tile_sample_stride_num_frames + 1, :, :] + ) + last_slice_data = slice_data + return torch.cat(result_slices, dim=2) + + def _process_parallel_tiled_outputs( + self, + results: torch.Tensor, + local_dim_metadata: list[torch.Size], + z: torch.Tensor, + world_size: int, + rank: int, + num_t_tiles: int, + num_h_tiles: int, + num_w_tiles: int, + total_spatial_tiles: int, + blend_height: int, + blend_width: int, + ) -> torch.Tensor: + local_size = torch.tensor( + [results.size(0)], device=results.device, dtype=torch.int64 + ) + all_sizes = [ + torch.zeros(1, device=results.device, dtype=torch.int64) + for _ in range(world_size) + ] + dist.all_gather(all_sizes, local_size) + max_size = max(size.item() for size in all_sizes) + + padded_results = torch.zeros( + max_size, device=results.device, dtype=results.dtype + ) + padded_results[: results.size(0)] = results + + gathered_dim_metadata = [None] * world_size + gathered_results = ( + torch.zeros_like(padded_results) + .repeat(world_size, *[1] * len(padded_results.shape)) + .contiguous() + ) + dist.all_gather_into_tensor(gathered_results, padded_results) + dist.all_gather_object(gathered_dim_metadata, local_dim_metadata) + gathered_dim_metadata = cast(list[list[torch.Size]], gathered_dim_metadata) + return self._merge_parallel_tiled_results( + gathered_results, + gathered_dim_metadata, + num_t_tiles, + num_h_tiles, + num_w_tiles, + total_spatial_tiles, + blend_height, + blend_width, + ) + def parallel_tiled_decode(self, z: torch.FloatTensor) -> torch.FloatTensor: """ Parallel version of tiled_decode that distributes both temporal and spatial computation across GPUs """ world_size, rank = get_sp_world_size(), get_sp_parallel_rank() - B, C, T, H, W = z.shape + _, _, T, H, W = z.shape - # Calculate parameters tile_latent_min_height = ( self.tile_sample_min_height // self.spatial_compression_ratio ) @@ -259,26 +355,22 @@ def parallel_tiled_decode(self, z: torch.FloatTensor) -> torch.FloatTensor: total_spatial_tiles = num_h_tiles * num_w_tiles total_tiles = num_t_tiles * total_spatial_tiles - # Calculate tiles per rank and padding tiles_per_rank = (total_tiles + world_size - 1) // world_size start_tile_idx = rank * tiles_per_rank end_tile_idx = min((rank + 1) * tiles_per_rank, total_tiles) local_results = [] local_dim_metadata = [] - # Process assigned tiles - for local_idx, global_idx in enumerate(range(start_tile_idx, end_tile_idx)): + for global_idx in range(start_tile_idx, end_tile_idx): t_idx = global_idx // total_spatial_tiles spatial_idx = global_idx % total_spatial_tiles h_idx = spatial_idx // num_w_tiles w_idx = spatial_idx % num_w_tiles - # Calculate positions t_start = t_idx * tile_latent_stride_num_frames h_start = h_idx * tile_latent_stride_height w_start = w_idx * tile_latent_stride_width - # Extract and process tile tile = z[ :, :, @@ -286,84 +378,31 @@ def parallel_tiled_decode(self, z: torch.FloatTensor) -> torch.FloatTensor: h_start : h_start + tile_latent_min_height, w_start : w_start + tile_latent_min_width, ] - - # Process tile - tile = self._decode(tile) - + decoded_tile = self._decode(tile) if t_start > 0: - tile = tile[:, :, 1:, :, :] - - # Store metadata - shape = tile.shape - # Store decoded data (flattened) - decoded_flat = tile.reshape(-1) - local_results.append(decoded_flat) - local_dim_metadata.append(shape) + decoded_tile = decoded_tile[:, :, 1:, :, :] + local_results.append(decoded_tile.reshape(-1)) + local_dim_metadata.append(decoded_tile.shape) - results = torch.cat(local_results, dim=0).contiguous() + if local_results: + results = torch.cat(local_results, dim=0).contiguous() + else: + results = z.new_empty((0,), dtype=z.dtype) del local_results - # first gather size to pad the results - local_size = torch.tensor( - [results.size(0)], device=results.device, dtype=torch.int64 - ) - all_sizes = [ - torch.zeros(1, device=results.device, dtype=torch.int64) - for _ in range(world_size) - ] - dist.all_gather(all_sizes, local_size) - max_size = max(size.item() for size in all_sizes) - padded_results = torch.zeros(max_size, device=results.device) - padded_results[: results.size(0)] = results - del results - - # Gather all results - gathered_dim_metadata = [None] * world_size - gathered_results = ( - torch.zeros_like(padded_results) - .repeat(world_size, *[1] * len(padded_results.shape)) - .contiguous() - ) # use contiguous to make sure it won't copy data in the following operations - # TODO (PY): use sgl_diffusion distributed methods - dist.all_gather_into_tensor(gathered_results, padded_results) - dist.all_gather_object(gathered_dim_metadata, local_dim_metadata) - # Process gathered results - data: list = [ - [[[] for _ in range(num_w_tiles)] for _ in range(num_h_tiles)] - for _ in range(num_t_tiles) - ] - for current_data, global_idx in self._parallel_data_generator( - gathered_results, gathered_dim_metadata - ): - t_idx = global_idx // total_spatial_tiles - spatial_idx = global_idx % total_spatial_tiles - h_idx = spatial_idx // num_w_tiles - w_idx = spatial_idx % num_w_tiles - data[t_idx][h_idx][w_idx] = current_data - # Merge results - result_slices = [] - last_slice_data = None - for i, tem_data in enumerate(data): - slice_data = self._merge_spatial_tiles( - tem_data, - blend_height, - blend_width, - self.tile_sample_stride_height, - self.tile_sample_stride_width, - ) - if i > 0: - slice_data = self.blend_t( - last_slice_data, slice_data, self.blend_num_frames - ) - result_slices.append( - slice_data[:, :, : self.tile_sample_stride_num_frames, :, :] - ) - else: - result_slices.append( - slice_data[:, :, : self.tile_sample_stride_num_frames + 1, :, :] - ) - last_slice_data = slice_data - dec = torch.cat(result_slices, dim=2) + dec = self._process_parallel_tiled_outputs( + results, + local_dim_metadata, + z, + world_size, + rank, + num_t_tiles, + num_h_tiles, + num_w_tiles, + total_spatial_tiles, + blend_height, + blend_width, + ) return dec def _merge_spatial_tiles(