From 4c7f60fec99012ec153aac5f43f258a8af1324be Mon Sep 17 00:00:00 2001 From: dongbo910220 <1275604947@qq.com> Date: Wed, 14 Jan 2026 01:30:35 +0800 Subject: [PATCH 01/39] diffusion(z-image): add VAE patch parallel decode Signed-off-by: dongbo910220 <1275604947@qq.com> --- vllm_omni/diffusion/data.py | 28 +++ .../models/z_image/pipeline_z_image.py | 198 +++++++++++++++++- 2 files changed, 225 insertions(+), 1 deletion(-) diff --git a/vllm_omni/diffusion/data.py b/vllm_omni/diffusion/data.py index 1a98462aaea..db92873f134 100644 --- a/vllm_omni/diffusion/data.py +++ b/vllm_omni/diffusion/data.py @@ -45,6 +45,9 @@ class DiffusionParallelConfig: cfg_parallel_size: int = 1 """Number of Classifier Free Guidance (CFG) parallel groups.""" + vae_patch_parallel_size: int = 1 + """Number of ranks used for VAE patch/tile parallelism (decode/encode).""" + @model_validator(mode="after") def _validate_parallel_config(self) -> Self: """Validates the config relationships among the parallel strategies.""" @@ -55,6 +58,7 @@ def _validate_parallel_config(self) -> Self: assert self.ulysses_degree > 0, "Ulysses degree must be > 0" assert self.ring_degree > 0, "Ring degree must be > 0" assert self.cfg_parallel_size > 0, "CFG parallel size must be > 0" + assert self.vae_patch_parallel_size > 0, "VAE patch parallel size must be > 0" assert self.sequence_parallel_size == self.ulysses_degree * self.ring_degree, ( "Sequence parallel size must be equal to the product of ulysses degree and ring degree," f" but got {self.sequence_parallel_size} != {self.ulysses_degree} * {self.ring_degree}" @@ -64,6 +68,21 @@ def _validate_parallel_config(self) -> Self: def __post_init__(self) -> None: if self.sequence_parallel_size is None: self.sequence_parallel_size = self.ulysses_degree * self.ring_degree + + env_override = os.environ.get("VLLM_DIFFUSION_VAE_PATCH_PARALLEL_SIZE") + if env_override is not None and self.vae_patch_parallel_size == 1: + try: + env_value = int(env_override) + if env_value > 0: + self.vae_patch_parallel_size = env_value + else: + logger.warning( + "Ignoring invalid VLLM_DIFFUSION_VAE_PATCH_PARALLEL_SIZE=%r (must be > 0).", + env_override, + ) + except ValueError: + logger.warning("Ignoring invalid VLLM_DIFFUSION_VAE_PATCH_PARALLEL_SIZE=%r.", env_override) + self.world_size = ( self.pipeline_parallel_size * self.data_parallel_size @@ -352,6 +371,9 @@ class OmniDiffusionConfig: supports_multimodal_inputs: bool = False # Logging + enable_vae_profiling: bool = False + """Enable lightweight VAE decode profiling logs.""" + log_level: str = "info" # Omni configuration (injected from stage config) @@ -448,6 +470,12 @@ def __post_init__(self): elif self.max_cpu_loras < 1: raise ValueError("max_cpu_loras must be >= 1 for diffusion LoRA") + if not self.enable_vae_profiling: + # Optional backdoor for quick experimentation without changing code/config. + env_override = os.getenv("VLLM_DIFFUSION_PROFILE_VAE") + if env_override is not None and env_override.strip().lower() not in ("0", "false", "off", "no"): + self.enable_vae_profiling = True + def update_multimodal_support(self) -> None: self.supports_multimodal_inputs = self.model_class_name in {"QwenImageEditPlusPipeline"} diff --git a/vllm_omni/diffusion/models/z_image/pipeline_z_image.py b/vllm_omni/diffusion/models/z_image/pipeline_z_image.py index b92d3de885f..17d1d699929 100644 --- a/vllm_omni/diffusion/models/z_image/pipeline_z_image.py +++ b/vllm_omni/diffusion/models/z_image/pipeline_z_image.py @@ -18,10 +18,12 @@ import inspect import json import os +import time from collections.abc import Callable, Iterable from typing import Any import torch +import torch.distributed as dist import torch.nn as nn from diffusers.image_processor import VaeImageProcessor from diffusers.models.autoencoders import AutoencoderKL @@ -29,9 +31,11 @@ from diffusers.utils import logging from diffusers.utils.torch_utils import randn_tensor from transformers import AutoModel, AutoTokenizer +from vllm.logger import init_logger as init_vllm_logger from vllm.model_executor.models.utils import AutoWeightsLoader from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.parallel_state import get_dit_group from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.models.z_image.z_image_transformer import ( @@ -43,6 +47,7 @@ ) logger = logging.get_logger(__name__) # pylint: disable=invalid-name +vllm_logger = init_vllm_logger(__name__) def get_post_process_func( @@ -217,6 +222,157 @@ def encode_prompt( negative_prompt_embeds = [] return prompt_embeds, negative_prompt_embeds + def _decode_latents_with_vae_patch_parallelism(self, latents: torch.Tensor) -> torch.Tensor: + """Decode latents with optional VAE patch/tile parallelism. + + - Only enabled when torch.distributed is initialized and VAE tiling is enabled. + - Falls back to the original `vae.decode` on any unsupported condition or error. + """ + vae_pp_size = getattr(self.od_config.parallel_config, "vae_patch_parallel_size", 1) + + if vae_pp_size <= 1 or not dist.is_initialized(): + return self.vae.decode(latents, return_dict=False)[0] + + # Only parallelize when the baseline already uses tiled_decode semantics. + if not getattr(self.vae, "use_tiling", False): + return self.vae.decode(latents, return_dict=False)[0] + + try: + return self._distributed_tiled_decode(latents, vae_patch_parallel_size=vae_pp_size) + except Exception as exc: + vllm_logger.warning("VAE patch parallel decode failed; falling back to vae.decode. %s", exc, exc_info=True) + return self.vae.decode(latents, return_dict=False)[0] + + def _distributed_tiled_decode(self, z: torch.Tensor, *, vae_patch_parallel_size: int) -> torch.Tensor: + """Distributed version of diffusers AutoencoderKL.tiled_decode (decode only). + + Each rank decodes a subset of tiles; rank0 gathers all tiles and performs the + original blend + stitch logic. Non-rank0 ranks return an empty tensor (their + output is ignored by the scheduler). + """ + group = get_dit_group() + world_size = dist.get_world_size(group) + rank = dist.get_rank(group) + + pp_size = min(int(vae_patch_parallel_size), int(world_size)) + if pp_size <= 1: + return self.vae.decode(z, return_dict=False)[0] + + tile_latent_min_size = getattr(self.vae, "tile_latent_min_size", None) + tile_overlap_factor = getattr(self.vae, "tile_overlap_factor", None) + tile_sample_min_size = getattr(self.vae, "tile_sample_min_size", None) + if tile_latent_min_size is None or tile_overlap_factor is None or tile_sample_min_size is None: + return self.vae.decode(z, return_dict=False)[0] + + overlap_size = int(tile_latent_min_size * (1 - tile_overlap_factor)) + if overlap_size <= 0: + return self.vae.decode(z, return_dict=False)[0] + + h_starts = list(range(0, z.shape[2], overlap_size)) + w_starts = list(range(0, z.shape[3], overlap_size)) + num_rows = len(h_starts) + num_cols = len(w_starts) + num_tiles = num_rows * num_cols + + if num_tiles < 2: + return self.vae.decode(z, return_dict=False)[0] + + blend_extent = int(tile_sample_min_size * tile_overlap_factor) + row_limit = int(tile_sample_min_size - blend_extent) + + # Decide which ranks actively decode tiles. + active = rank < pp_size + + local_tiles: list[torch.Tensor] = [] + local_meta: list[tuple[int, int, int]] = [] + + tile_id = 0 + for i in h_starts: + for j in w_starts: + if active and (tile_id % pp_size == rank): + tile = z[:, :, i : i + tile_latent_min_size, j : j + tile_latent_min_size] + if getattr(self.vae.config, "use_post_quant_conv", False): + tile = self.vae.post_quant_conv(tile) + decoded = self.vae.decoder(tile) + local_tiles.append(decoded) + local_meta.append((tile_id, int(decoded.shape[-2]), int(decoded.shape[-1]))) + tile_id += 1 + + # Gather per-rank tile counts. + count_tensor = torch.tensor([len(local_tiles)], device=z.device, dtype=torch.int64) + if rank == 0: + count_gather = [torch.empty_like(count_tensor) for _ in range(world_size)] + else: + count_gather = None + dist.gather(count_tensor, gather_list=count_gather, dst=0, group=group) + max_count = 0 + if rank == 0: + counts = [int(t.item()) for t in count_gather] # type: ignore[arg-type] + max_count = max(counts) if counts else 0 + max_count_tensor = torch.tensor([max_count], device=z.device, dtype=torch.int64) + dist.broadcast(max_count_tensor, src=0, group=group) + max_count = int(max_count_tensor.item()) + + # Prepare padded metadata + tiles for gather. + meta_tensor = torch.full((max_count, 3), -1, device=z.device, dtype=torch.int64) + tile_tensor = torch.zeros( + (max_count, z.shape[0], 3, tile_sample_min_size, tile_sample_min_size), + device=z.device, + dtype=z.dtype, + ) + for idx, (tile_id, h, w) in enumerate(local_meta): + meta_tensor[idx, 0] = tile_id + meta_tensor[idx, 1] = h + meta_tensor[idx, 2] = w + tile_tensor[idx, :, :, :h, :w] = local_tiles[idx] + + if rank == 0: + meta_gather = [torch.empty_like(meta_tensor) for _ in range(world_size)] + tile_gather = [torch.empty_like(tile_tensor) for _ in range(world_size)] + else: + meta_gather = None + tile_gather = None + + dist.gather(meta_tensor, gather_list=meta_gather, dst=0, group=group) + dist.gather(tile_tensor, gather_list=tile_gather, dst=0, group=group) + + if rank != 0: + return torch.empty(0, device=z.device, dtype=z.dtype) + + # Reconstruct the full tile grid on rank0. + tile_map: dict[int, torch.Tensor] = {} + for src_rank in range(world_size): + meta_src = meta_gather[src_rank] # type: ignore[index] + tiles_src = tile_gather[src_rank] # type: ignore[index] + for idx in range(max_count): + tid = int(meta_src[idx, 0].item()) + if tid < 0: + continue + h = int(meta_src[idx, 1].item()) + w = int(meta_src[idx, 2].item()) + tile_map[tid] = tiles_src[idx, :, :, :h, :w] + + rows: list[list[torch.Tensor]] = [] + for r in range(num_rows): + row: list[torch.Tensor] = [] + for c in range(num_cols): + tid = r * num_cols + c + row.append(tile_map[tid]) + rows.append(row) + + result_rows: list[torch.Tensor] = [] + for i, row in enumerate(rows): + result_row: list[torch.Tensor] = [] + for j, tile in enumerate(row): + if i > 0: + tile = self.vae.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.vae.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=3)) + + return torch.cat(result_rows, dim=2) + def _encode_prompt( self, prompt: str | list[str], @@ -607,7 +763,47 @@ def forward( latents = latents.to(self.vae.dtype) latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor - image = self.vae.decode(latents, return_dict=False)[0] + profile_vae = self.od_config.enable_vae_profiling + if profile_vae: + device = latents.device + dist_rank = None + dist_world_size = None + if dist.is_initialized(): + try: + dist_rank = dist.get_rank() + dist_world_size = dist.get_world_size() + except Exception: + dist_rank = None + dist_world_size = None + if device.type == "cuda": + torch.cuda.reset_peak_memory_stats(device) + torch.cuda.synchronize(device) + t0 = time.perf_counter() + image = self._decode_latents_with_vae_patch_parallelism(latents) + if device.type == "cuda": + torch.cuda.synchronize(device) + t1 = time.perf_counter() + if device.type == "cuda": + peak_alloc_gib = torch.cuda.max_memory_allocated(device) / (1024**3) + peak_resv_gib = torch.cuda.max_memory_reserved(device) / (1024**3) + vllm_logger.debug( + "Z-Image VAE decode profile: rank=%s/%s time_ms=%.3f " + "peak_alloc_gib=%.3f peak_reserved_gib=%.3f", + dist_rank if dist_rank is not None else "na", + dist_world_size if dist_world_size is not None else "na", + (t1 - t0) * 1000, + peak_alloc_gib, + peak_resv_gib, + ) + else: + vllm_logger.debug( + "Z-Image VAE decode profile: rank=%s/%s time_ms=%.3f", + dist_rank if dist_rank is not None else "na", + dist_world_size if dist_world_size is not None else "na", + (t1 - t0) * 1000, + ) + else: + image = self._decode_latents_with_vae_patch_parallelism(latents) # image = self.image_processor.postprocess(image, output_type=output_type) return DiffusionOutput(output=image) From 5bddc7cc59a3dddafc5ab21036c55adb665a8ffa Mon Sep 17 00:00:00 2001 From: dongbo910220 <1275604947@qq.com> Date: Wed, 14 Jan 2026 01:30:35 +0800 Subject: [PATCH 02/39] tests/examples: cover VAE patch parallelism Signed-off-by: dongbo910220 <1275604947@qq.com> --- .../text_to_image/text_to_image.py | 27 +++++- .../test_zimage_tensor_parallel.py | 95 +++++++++++++++++-- 2 files changed, 112 insertions(+), 10 deletions(-) diff --git a/examples/offline_inference/text_to_image/text_to_image.py b/examples/offline_inference/text_to_image/text_to_image.py index 428c2437e6f..8119fc2358f 100644 --- a/examples/offline_inference/text_to_image/text_to_image.py +++ b/examples/offline_inference/text_to_image/text_to_image.py @@ -11,7 +11,7 @@ from vllm_omni.diffusion.data import DiffusionParallelConfig, logger from vllm_omni.entrypoints.omni import Omni from vllm_omni.outputs import OmniRequestOutput -from vllm_omni.utils.platform_utils import detect_device_type, is_npu +from vllm_omni.utils.platform_utils import detect_device_type def parse_args() -> argparse.Namespace: @@ -101,12 +101,28 @@ def parse_args() -> argparse.Namespace: action="store_true", help="Enable CPU offloading for diffusion models.", ) + parser.add_argument( + "--vae_use_slicing", + action="store_true", + help="Enable VAE slicing (memory optimization).", + ) + parser.add_argument( + "--vae_use_tiling", + action="store_true", + help="Enable VAE tiling (memory optimization).", + ) parser.add_argument( "--tensor_parallel_size", type=int, default=1, help="Number of GPUs used for tensor parallelism (TP) inside the DiT.", ) + parser.add_argument( + "--vae_patch_parallel_size", + type=int, + default=1, + help="Number of ranks used for VAE patch/tile parallelism (decode/encode).", + ) return parser.parse_args() @@ -116,8 +132,8 @@ def main(): generator = torch.Generator(device=device).manual_seed(args.seed) # Enable VAE memory optimizations on NPU - vae_use_slicing = is_npu() - vae_use_tiling = is_npu() + vae_use_slicing = args.vae_use_slicing or device == "npu" + vae_use_tiling = args.vae_use_tiling or device == "npu" # Configure cache based on backend type cache_config = None @@ -154,6 +170,7 @@ def main(): ring_degree=args.ring_degree, cfg_parallel_size=args.cfg_parallel_size, tensor_parallel_size=args.tensor_parallel_size, + vae_patch_parallel_size=args.vae_patch_parallel_size, ) # Check if profiling is requested via environment variable @@ -182,8 +199,10 @@ def main(): print(f" Cache backend: {args.cache_backend if args.cache_backend else 'None (no acceleration)'}") print( f" Parallel configuration: tensor_parallel_size={args.tensor_parallel_size}, " - f"ulysses_degree={args.ulysses_degree}, ring_degree={args.ring_degree}, cfg_parallel_size={args.cfg_parallel_size}" + f"ulysses_degree={args.ulysses_degree}, ring_degree={args.ring_degree}, cfg_parallel_size={args.cfg_parallel_size}, " + f"vae_patch_parallel_size={args.vae_patch_parallel_size}" ) + print(f" CPU offload: {args.enable_cpu_offload}") print(f" Image size: {args.width}x{args.height}") print(f"{'=' * 60}\n") diff --git a/tests/e2e/offline_inference/test_zimage_tensor_parallel.py b/tests/e2e/offline_inference/test_zimage_tensor_parallel.py index a9fdec4dc0b..f864f9829fd 100644 --- a/tests/e2e/offline_inference/test_zimage_tensor_parallel.py +++ b/tests/e2e/offline_inference/test_zimage_tensor_parallel.py @@ -66,17 +66,38 @@ def _extract_single_image(outputs) -> Image.Image: return images[0] +def _get_enforce_eager_for_cuda() -> bool: + cc_major, _cc_minor = torch.cuda.get_device_capability(0) + return cc_major < 8 + + def _run_zimage_generate( - *, tp_size: int, height: int, width: int, num_inference_steps: int, seed: int + *, + tp_size: int, + height: int, + width: int, + num_inference_steps: int, + seed: int, + enforce_eager: bool, + vae_use_tiling: bool = False, + vae_patch_parallel_size: int = 1, + num_requests: int = 4, ) -> tuple[Image.Image, float, float]: + if num_requests < 2: + raise ValueError("num_requests must be >= 2 (1 warmup + >=1 timed)") + torch.cuda.empty_cache() device_index = torch.cuda.current_device() monitor = GPUMemoryMonitor(device_index=device_index, interval=0.02) monitor.start() - m = Omni( model=_get_zimage_model(), - parallel_config=DiffusionParallelConfig(tensor_parallel_size=tp_size), + parallel_config=DiffusionParallelConfig( + tensor_parallel_size=tp_size, + vae_patch_parallel_size=vae_patch_parallel_size, + ), + enforce_eager=enforce_eager, + vae_use_tiling=vae_use_tiling, ) try: # NOTE: Omni closes itself when a generate() call is exhausted. @@ -87,7 +108,6 @@ def _run_zimage_generate( # This also serves as a warmup: the first output may include extra # compilation/caching overhead, while later outputs are closer to # steady-state inference. - num_requests = 4 # 1 warmup + 3 timed gen = m.generate( [PROMPT] * num_requests, height=height, @@ -120,6 +140,7 @@ def _run_zimage_generate( return _extract_single_image([last_output]), median_time_s, peak_memory_mb finally: monitor.stop() + m.close() cleanup_dist_env_and_memory() @@ -130,8 +151,10 @@ def test_zimage_tensor_parallel_tp2(tmp_path: Path): if not torch.cuda.is_available() or torch.cuda.device_count() < 2: pytest.skip("Z-Image TP=2 requires >= 2 CUDA devices.") - height = 512 - width = 512 + enforce_eager = _get_enforce_eager_for_cuda() + + height = 256 + width = 256 num_inference_steps = 2 seed = 42 @@ -141,6 +164,7 @@ def test_zimage_tensor_parallel_tp2(tmp_path: Path): width=width, num_inference_steps=num_inference_steps, seed=seed, + enforce_eager=enforce_eager, ) tp2_img, tp2_time_s, tp2_peak_mem = _run_zimage_generate( tp_size=2, @@ -148,6 +172,7 @@ def test_zimage_tensor_parallel_tp2(tmp_path: Path): width=width, num_inference_steps=num_inference_steps, seed=seed, + enforce_eager=enforce_eager, ) tp1_path = tmp_path / "zimage_tp1.png" @@ -179,3 +204,61 @@ def test_zimage_tensor_parallel_tp2(tmp_path: Path): assert tp2_peak_mem < tp1_peak_mem, ( f"Expected TP=2 to use less peak memory than TP=1 (tp1={tp1_peak_mem}, tp2={tp2_peak_mem})" ) + + +@pytest.mark.integration +def test_zimage_vae_patch_parallel_tp2(tmp_path: Path): + if is_npu() or is_rocm(): + pytest.skip("Z-Image VAE patch parallel e2e test is only supported on CUDA for now.") + if not torch.cuda.is_available() or torch.cuda.device_count() < 2: + pytest.skip("Z-Image VAE patch parallel TP=2 requires >= 2 CUDA devices.") + + enforce_eager = _get_enforce_eager_for_cuda() + + # Use a larger image to ensure there are multiple VAE tiles. + height = 768 + width = 768 + num_inference_steps = 2 + seed = 42 + + baseline_img, _baseline_time_s, _baseline_peak_mem = _run_zimage_generate( + tp_size=2, + height=height, + width=width, + num_inference_steps=num_inference_steps, + seed=seed, + enforce_eager=enforce_eager, + vae_use_tiling=True, + vae_patch_parallel_size=1, + num_requests=2, + ) + pp2_img, _pp2_time_s, _pp2_peak_mem = _run_zimage_generate( + tp_size=2, + height=height, + width=width, + num_inference_steps=num_inference_steps, + seed=seed, + enforce_eager=enforce_eager, + vae_use_tiling=True, + vae_patch_parallel_size=2, + num_requests=2, + ) + + baseline_path = tmp_path / "zimage_tp2_vae_pp1.png" + pp2_path = tmp_path / "zimage_tp2_vae_pp2.png" + baseline_img.save(baseline_path) + pp2_img.save(pp2_path) + + mean_abs_diff, max_abs_diff = _diff_metrics(baseline_img, pp2_img) + mean_threshold = 5e-3 + max_threshold = 1e-1 + print( + "Z-Image VAE patch parallel image diff stats (TP=2, pp=1 vs pp=2): " + f"mean_abs_diff={mean_abs_diff:.6e}, max_abs_diff={max_abs_diff:.6e}; " + f"thresholds: mean<={mean_threshold:.6e}, max<={max_threshold:.6e}; " + f"pp1_img={baseline_path}, pp2_img={pp2_path}" + ) + assert mean_abs_diff <= mean_threshold and max_abs_diff <= max_threshold, ( + f"Image diff exceeded threshold: mean_abs_diff={mean_abs_diff:.6e}, max_abs_diff={max_abs_diff:.6e} " + f"(thresholds: mean<={mean_threshold:.6e}, max<={max_threshold:.6e})" + ) From f72343fa9af7e768f1f104ba1161520663e4f31c Mon Sep 17 00:00:00 2001 From: dongbo910220 <1275604947@qq.com> Date: Wed, 14 Jan 2026 01:30:35 +0800 Subject: [PATCH 03/39] examples: keep VAE slicing opt-in Co-authored-by: hsliuustc0106 Signed-off-by: dongbo910220 <1275604947@qq.com> --- examples/offline_inference/text_to_image/text_to_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/offline_inference/text_to_image/text_to_image.py b/examples/offline_inference/text_to_image/text_to_image.py index 8119fc2358f..f66806fd31b 100644 --- a/examples/offline_inference/text_to_image/text_to_image.py +++ b/examples/offline_inference/text_to_image/text_to_image.py @@ -132,7 +132,7 @@ def main(): generator = torch.Generator(device=device).manual_seed(args.seed) # Enable VAE memory optimizations on NPU - vae_use_slicing = args.vae_use_slicing or device == "npu" + vae_use_slicing = args.vae_use_slicing vae_use_tiling = args.vae_use_tiling or device == "npu" # Configure cache based on backend type From 4aa881984242a4e17bacb16860c36c08ae18769e Mon Sep 17 00:00:00 2001 From: dongbo910220 <1275604947@qq.com> Date: Wed, 14 Jan 2026 01:30:35 +0800 Subject: [PATCH 04/39] examples: keep VAE tiling opt-in Co-authored-by: hsliuustc0106 Signed-off-by: dongbo910220 <1275604947@qq.com> --- examples/offline_inference/text_to_image/text_to_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/offline_inference/text_to_image/text_to_image.py b/examples/offline_inference/text_to_image/text_to_image.py index f66806fd31b..f0fcca29679 100644 --- a/examples/offline_inference/text_to_image/text_to_image.py +++ b/examples/offline_inference/text_to_image/text_to_image.py @@ -133,7 +133,7 @@ def main(): # Enable VAE memory optimizations on NPU vae_use_slicing = args.vae_use_slicing - vae_use_tiling = args.vae_use_tiling or device == "npu" + vae_use_tiling = args.vae_use_tiling # Configure cache based on backend type cache_config = None From 01dc18cc76b275d7a12b138fb5e9ecc07b0a085b Mon Sep 17 00:00:00 2001 From: dongbo910220 <1275604947@qq.com> Date: Wed, 14 Jan 2026 02:30:52 +0800 Subject: [PATCH 05/39] diffusion(z-image): prototype DistVAE-style patch decode Signed-off-by: dongbo910220 <1275604947@qq.com> --- .../models/z_image/pipeline_z_image.py | 158 +++++++++++++++++- 1 file changed, 157 insertions(+), 1 deletion(-) diff --git a/vllm_omni/diffusion/models/z_image/pipeline_z_image.py b/vllm_omni/diffusion/models/z_image/pipeline_z_image.py index 17d1d699929..691bf120d99 100644 --- a/vllm_omni/diffusion/models/z_image/pipeline_z_image.py +++ b/vllm_omni/diffusion/models/z_image/pipeline_z_image.py @@ -17,6 +17,7 @@ import inspect import json +import math import os import time from collections.abc import Callable, Iterable @@ -238,7 +239,20 @@ def _decode_latents_with_vae_patch_parallelism(self, latents: torch.Tensor) -> t return self.vae.decode(latents, return_dict=False)[0] try: - return self._distributed_tiled_decode(latents, vae_patch_parallel_size=vae_pp_size) + tile_latent_min_size = getattr(self.vae, "tile_latent_min_size", None) + if tile_latent_min_size is None: + return self._distributed_tiled_decode(latents, vae_patch_parallel_size=vae_pp_size) + + # Match diffusers' condition for when VAE tiling would be used. + should_tile = (latents.shape[-1] > tile_latent_min_size) or (latents.shape[-2] > tile_latent_min_size) + if should_tile: + return self._distributed_tiled_decode(latents, vae_patch_parallel_size=vae_pp_size) + + # For the boundary size where tiling wouldn't kick in, decode one spatial block per rank. + if (latents.shape[-1] >= tile_latent_min_size) and (latents.shape[-2] >= tile_latent_min_size): + return self._distributed_patch_decode(latents, vae_patch_parallel_size=vae_pp_size) + + return self.vae.decode(latents, return_dict=False)[0] except Exception as exc: vllm_logger.warning("VAE patch parallel decode failed; falling back to vae.decode. %s", exc, exc_info=True) return self.vae.decode(latents, return_dict=False)[0] @@ -373,6 +387,148 @@ def _distributed_tiled_decode(self, z: torch.Tensor, *, vae_patch_parallel_size: return torch.cat(result_rows, dim=2) + @staticmethod + def _factor_pp_grid(pp_size: int) -> tuple[int, int]: + """Pick a (rows, cols) grid whose product equals `pp_size`.""" + if pp_size <= 1: + return (1, 1) + root = int(math.sqrt(pp_size)) + for rows in range(root, 0, -1): + if pp_size % rows == 0: + return (rows, pp_size // rows) + return (1, pp_size) + + def _distributed_patch_decode(self, z: torch.Tensor, *, vae_patch_parallel_size: int) -> torch.Tensor: + """Decode one spatial block per rank, then stitch RGB blocks on rank0. + + This is a DistVAE-style variant intended for cases where diffusers tiling would + not kick in (e.g., `tile_latent_min_size >= z.shape[-2/-1]`) but we still want + to reduce the per-rank VAE decode activation peak. + + Each active rank decodes exactly one core block with an optional latent-space + halo. The halo is cropped away before gathering, so rank0 only stitches the + final RGB blocks. + """ + group = get_dit_group() + world_size = dist.get_world_size(group) + rank = dist.get_rank(group) + + pp_size = min(int(vae_patch_parallel_size), int(world_size)) + if pp_size <= 1: + return self.vae.decode(z, return_dict=False)[0] + + # Only ranks < pp_size participate in decoding. Others send empty tensors. + active = rank < pp_size + + bsz, _, latent_h, latent_w = z.shape + scale = int(self.vae_scale_factor) + out_h = latent_h * scale + out_w = latent_w * scale + + local_core = torch.empty(0, device=z.device, dtype=z.dtype) + local_h = 0 + local_w = 0 + + if active: + grid_rows, grid_cols = self._factor_pp_grid(pp_size) + patch_idx = rank + patch_row = patch_idx // grid_cols + patch_col = patch_idx % grid_cols + + h0 = (patch_row * latent_h) // grid_rows + h1 = ((patch_row + 1) * latent_h) // grid_rows + w0 = (patch_col * latent_w) // grid_cols + w1 = ((patch_col + 1) * latent_w) // grid_cols + + core_h = max(0, h1 - h0) + core_w = max(0, w1 - w0) + if core_h == 0 or core_w == 0: + local_core = torch.empty(0, device=z.device, dtype=z.dtype) + else: + overlap_factor = float(getattr(self.vae, "tile_overlap_factor", 0.25)) + halo = int(min(core_h, core_w) * overlap_factor) + halo = max(0, halo) + + ph0 = max(0, h0 - halo) + ph1 = min(latent_h, h1 + halo) + pw0 = max(0, w0 - halo) + pw1 = min(latent_w, w1 + halo) + + tile = z[:, :, ph0:ph1, pw0:pw1] + if getattr(self.vae.config, "use_post_quant_conv", False): + tile = self.vae.post_quant_conv(tile) + decoded = self.vae.decoder(tile) + + ch0 = (h0 - ph0) * scale + cw0 = (w0 - pw0) * scale + ch1 = ch0 + core_h * scale + cw1 = cw0 + core_w * scale + + local_core = decoded[:, :, ch0:ch1, cw0:cw1] + + local_h = int(local_core.shape[-2]) if local_core.numel() else 0 + local_w = int(local_core.shape[-1]) if local_core.numel() else 0 + + # Gather block shapes. + shape_tensor = torch.tensor([local_h, local_w], device=z.device, dtype=torch.int64) + if rank == 0: + shape_gather = [torch.empty_like(shape_tensor) for _ in range(world_size)] + else: + shape_gather = None + dist.gather(shape_tensor, gather_list=shape_gather, dst=0, group=group) + + max_h = 0 + max_w = 0 + if rank == 0: + shapes = [tuple(int(x.item()) for x in t) for t in shape_gather] # type: ignore[arg-type] + max_h = max((h for h, _ in shapes), default=0) + max_w = max((w for _, w in shapes), default=0) + + max_hw_tensor = torch.tensor([max_h, max_w], device=z.device, dtype=torch.int64) + dist.broadcast(max_hw_tensor, src=0, group=group) + max_h = int(max_hw_tensor[0].item()) + max_w = int(max_hw_tensor[1].item()) + + # Pad local block for gather. + if max_h == 0 or max_w == 0: + padded = torch.empty(0, device=z.device, dtype=z.dtype) + else: + padded = torch.zeros((bsz, 3, max_h, max_w), device=z.device, dtype=z.dtype) + if local_h and local_w: + padded[:, :, :local_h, :local_w] = local_core + + if rank == 0: + block_gather = [torch.empty_like(padded) for _ in range(world_size)] + else: + block_gather = None + dist.gather(padded, gather_list=block_gather, dst=0, group=group) + + if rank != 0: + return torch.empty(0, device=z.device, dtype=z.dtype) + + # Stitch on rank0. + out = torch.empty((bsz, 3, out_h, out_w), device=z.device, dtype=z.dtype) + + grid_rows, grid_cols = self._factor_pp_grid(pp_size) + for patch_idx in range(pp_size): + src_rank = patch_idx + patch_row = patch_idx // grid_cols + patch_col = patch_idx % grid_cols + h0 = (patch_row * latent_h) // grid_rows + h1 = ((patch_row + 1) * latent_h) // grid_rows + w0 = (patch_col * latent_w) // grid_cols + w1 = ((patch_col + 1) * latent_w) // grid_cols + + ph = (h1 - h0) * scale + pw = (w1 - w0) * scale + if ph <= 0 or pw <= 0: + continue + + tile = block_gather[src_rank] # type: ignore[index] + out[:, :, h0 * scale : h1 * scale, w0 * scale : w1 * scale] = tile[:, :, :ph, :pw] + + return out + def _encode_prompt( self, prompt: str | list[str], From dfccf30b14c4646a4f6ce37f19e15132ac2bcf89 Mon Sep 17 00:00:00 2001 From: dongbo910220 <1275604947@qq.com> Date: Wed, 14 Jan 2026 05:19:31 +0800 Subject: [PATCH 06/39] diffusion(z-image): extend patch decode and balance tiles Signed-off-by: dongbo910220 <1275604947@qq.com> --- .../test_zimage_tensor_parallel.py | 6 ++- .../models/z_image/pipeline_z_image.py | 42 ++++++++++--------- 2 files changed, 26 insertions(+), 22 deletions(-) diff --git a/tests/e2e/offline_inference/test_zimage_tensor_parallel.py b/tests/e2e/offline_inference/test_zimage_tensor_parallel.py index f864f9829fd..2907588653c 100644 --- a/tests/e2e/offline_inference/test_zimage_tensor_parallel.py +++ b/tests/e2e/offline_inference/test_zimage_tensor_parallel.py @@ -216,8 +216,10 @@ def test_zimage_vae_patch_parallel_tp2(tmp_path: Path): enforce_eager = _get_enforce_eager_for_cuda() # Use a larger image to ensure there are multiple VAE tiles. - height = 768 - width = 768 + # For Z-Image-Turbo, VAE tiling kicks in when latent_h/latent_w > 128. + # 1152x1152 -> latent 144x144. + height = 1152 + width = 1152 num_inference_steps = 2 seed = 42 diff --git a/vllm_omni/diffusion/models/z_image/pipeline_z_image.py b/vllm_omni/diffusion/models/z_image/pipeline_z_image.py index 691bf120d99..1663bcfa012 100644 --- a/vllm_omni/diffusion/models/z_image/pipeline_z_image.py +++ b/vllm_omni/diffusion/models/z_image/pipeline_z_image.py @@ -248,11 +248,9 @@ def _decode_latents_with_vae_patch_parallelism(self, latents: torch.Tensor) -> t if should_tile: return self._distributed_tiled_decode(latents, vae_patch_parallel_size=vae_pp_size) - # For the boundary size where tiling wouldn't kick in, decode one spatial block per rank. - if (latents.shape[-1] >= tile_latent_min_size) and (latents.shape[-2] >= tile_latent_min_size): - return self._distributed_patch_decode(latents, vae_patch_parallel_size=vae_pp_size) - - return self.vae.decode(latents, return_dict=False)[0] + # For cases where diffusers tiling would not kick in, decode overlapped + # spatial patches per rank and blend/stitch on rank0. + return self._distributed_patch_decode(latents, vae_patch_parallel_size=vae_pp_size) except Exception as exc: vllm_logger.warning("VAE patch parallel decode failed; falling back to vae.decode. %s", exc, exc_info=True) return self.vae.decode(latents, return_dict=False)[0] @@ -303,7 +301,9 @@ def _distributed_tiled_decode(self, z: torch.Tensor, *, vae_patch_parallel_size: tile_id = 0 for i in h_starts: for j in w_starts: - if active and (tile_id % pp_size == rank): + # Offset assignment by 1 so rank0 avoids decoding the largest (tile_id=0) tile. + tile_rank = (tile_id + 1) % pp_size + if active and (tile_rank == rank): tile = z[:, :, i : i + tile_latent_min_size, j : j + tile_latent_min_size] if getattr(self.vae.config, "use_post_quant_conv", False): tile = self.vae.post_quant_conv(tile) @@ -401,13 +401,10 @@ def _factor_pp_grid(pp_size: int) -> tuple[int, int]: def _distributed_patch_decode(self, z: torch.Tensor, *, vae_patch_parallel_size: int) -> torch.Tensor: """Decode one spatial block per rank, then stitch RGB blocks on rank0. - This is a DistVAE-style variant intended for cases where diffusers tiling would - not kick in (e.g., `tile_latent_min_size >= z.shape[-2/-1]`) but we still want - to reduce the per-rank VAE decode activation peak. - - Each active rank decodes exactly one core block with an optional latent-space - halo. The halo is cropped away before gathering, so rank0 only stitches the - final RGB blocks. + Intended for sizes where diffusers tiling would not kick in, so we can still + reduce the per-rank VAE decode activation peak. Each rank decodes a core + block with a small latent-space halo, then crops to the core and gathers the + RGB blocks to rank0 for final stitching. """ group = get_dit_group() world_size = dist.get_world_size(group) @@ -417,6 +414,14 @@ def _distributed_patch_decode(self, z: torch.Tensor, *, vae_patch_parallel_size: if pp_size <= 1: return self.vae.decode(z, return_dict=False)[0] + tile_latent_min_size = getattr(self.vae, "tile_latent_min_size", None) + tile_overlap_factor = getattr(self.vae, "tile_overlap_factor", None) + if tile_latent_min_size is None or tile_overlap_factor is None: + return self.vae.decode(z, return_dict=False)[0] + + overlap_latent = int(tile_latent_min_size * float(tile_overlap_factor)) + halo_base = max(0, overlap_latent // 2) + # Only ranks < pp_size participate in decoding. Others send empty tensors. active = rank < pp_size @@ -429,8 +434,9 @@ def _distributed_patch_decode(self, z: torch.Tensor, *, vae_patch_parallel_size: local_h = 0 local_w = 0 + grid_rows, grid_cols = self._factor_pp_grid(pp_size) + if active: - grid_rows, grid_cols = self._factor_pp_grid(pp_size) patch_idx = rank patch_row = patch_idx // grid_cols patch_col = patch_idx % grid_cols @@ -445,10 +451,7 @@ def _distributed_patch_decode(self, z: torch.Tensor, *, vae_patch_parallel_size: if core_h == 0 or core_w == 0: local_core = torch.empty(0, device=z.device, dtype=z.dtype) else: - overlap_factor = float(getattr(self.vae, "tile_overlap_factor", 0.25)) - halo = int(min(core_h, core_w) * overlap_factor) - halo = max(0, halo) - + halo = max(halo_base, min(core_h, core_w) // 2) ph0 = max(0, h0 - halo) ph1 = min(latent_h, h1 + halo) pw0 = max(0, w0 - halo) @@ -463,7 +466,6 @@ def _distributed_patch_decode(self, z: torch.Tensor, *, vae_patch_parallel_size: cw0 = (w0 - pw0) * scale ch1 = ch0 + core_h * scale cw1 = cw0 + core_w * scale - local_core = decoded[:, :, ch0:ch1, cw0:cw1] local_h = int(local_core.shape[-2]) if local_core.numel() else 0 @@ -509,11 +511,11 @@ def _distributed_patch_decode(self, z: torch.Tensor, *, vae_patch_parallel_size: # Stitch on rank0. out = torch.empty((bsz, 3, out_h, out_w), device=z.device, dtype=z.dtype) - grid_rows, grid_cols = self._factor_pp_grid(pp_size) for patch_idx in range(pp_size): src_rank = patch_idx patch_row = patch_idx // grid_cols patch_col = patch_idx % grid_cols + h0 = (patch_row * latent_h) // grid_rows h1 = ((patch_row + 1) * latent_h) // grid_rows w0 = (patch_col * latent_w) // grid_cols From 4997ab9ca11c2b3a5bd63d1513631ea1bee44d68 Mon Sep 17 00:00:00 2001 From: dongbo910220 <1275604947@qq.com> Date: Wed, 14 Jan 2026 05:21:41 +0800 Subject: [PATCH 07/39] diffusion(z-image): clarify patch decode comment Signed-off-by: dongbo910220 <1275604947@qq.com> --- vllm_omni/diffusion/models/z_image/pipeline_z_image.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm_omni/diffusion/models/z_image/pipeline_z_image.py b/vllm_omni/diffusion/models/z_image/pipeline_z_image.py index 1663bcfa012..baf972a49a5 100644 --- a/vllm_omni/diffusion/models/z_image/pipeline_z_image.py +++ b/vllm_omni/diffusion/models/z_image/pipeline_z_image.py @@ -248,8 +248,8 @@ def _decode_latents_with_vae_patch_parallelism(self, latents: torch.Tensor) -> t if should_tile: return self._distributed_tiled_decode(latents, vae_patch_parallel_size=vae_pp_size) - # For cases where diffusers tiling would not kick in, decode overlapped - # spatial patches per rank and blend/stitch on rank0. + # For cases where diffusers tiling would not kick in, decode spatial + # patches per rank and stitch on rank0. return self._distributed_patch_decode(latents, vae_patch_parallel_size=vae_pp_size) except Exception as exc: vllm_logger.warning("VAE patch parallel decode failed; falling back to vae.decode. %s", exc, exc_info=True) From 3bb5cca108b86e95e0d85a01e8efd5504b411552 Mon Sep 17 00:00:00 2001 From: dongbo910220 <1275604947@qq.com> Date: Wed, 14 Jan 2026 16:14:53 +0800 Subject: [PATCH 08/39] diffusion: refactor VAE patch parallelism (ref_dit) Signed-off-by: dongbo910220 <1275604947@qq.com> --- .../models/z_image/pipeline_z_image.py | 314 +----------- vllm_omni/diffusion/registry.py | 9 + vllm_omni/diffusion/vae/__init__.py | 2 + vllm_omni/diffusion/vae/patch_parallelism.py | 451 ++++++++++++++++++ 4 files changed, 464 insertions(+), 312 deletions(-) create mode 100644 vllm_omni/diffusion/vae/__init__.py create mode 100644 vllm_omni/diffusion/vae/patch_parallelism.py diff --git a/vllm_omni/diffusion/models/z_image/pipeline_z_image.py b/vllm_omni/diffusion/models/z_image/pipeline_z_image.py index baf972a49a5..2d5e832c0e2 100644 --- a/vllm_omni/diffusion/models/z_image/pipeline_z_image.py +++ b/vllm_omni/diffusion/models/z_image/pipeline_z_image.py @@ -17,7 +17,6 @@ import inspect import json -import math import os import time from collections.abc import Callable, Iterable @@ -36,7 +35,6 @@ from vllm.model_executor.models.utils import AutoWeightsLoader from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig -from vllm_omni.diffusion.distributed.parallel_state import get_dit_group from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.models.z_image.z_image_transformer import ( @@ -223,314 +221,6 @@ def encode_prompt( negative_prompt_embeds = [] return prompt_embeds, negative_prompt_embeds - def _decode_latents_with_vae_patch_parallelism(self, latents: torch.Tensor) -> torch.Tensor: - """Decode latents with optional VAE patch/tile parallelism. - - - Only enabled when torch.distributed is initialized and VAE tiling is enabled. - - Falls back to the original `vae.decode` on any unsupported condition or error. - """ - vae_pp_size = getattr(self.od_config.parallel_config, "vae_patch_parallel_size", 1) - - if vae_pp_size <= 1 or not dist.is_initialized(): - return self.vae.decode(latents, return_dict=False)[0] - - # Only parallelize when the baseline already uses tiled_decode semantics. - if not getattr(self.vae, "use_tiling", False): - return self.vae.decode(latents, return_dict=False)[0] - - try: - tile_latent_min_size = getattr(self.vae, "tile_latent_min_size", None) - if tile_latent_min_size is None: - return self._distributed_tiled_decode(latents, vae_patch_parallel_size=vae_pp_size) - - # Match diffusers' condition for when VAE tiling would be used. - should_tile = (latents.shape[-1] > tile_latent_min_size) or (latents.shape[-2] > tile_latent_min_size) - if should_tile: - return self._distributed_tiled_decode(latents, vae_patch_parallel_size=vae_pp_size) - - # For cases where diffusers tiling would not kick in, decode spatial - # patches per rank and stitch on rank0. - return self._distributed_patch_decode(latents, vae_patch_parallel_size=vae_pp_size) - except Exception as exc: - vllm_logger.warning("VAE patch parallel decode failed; falling back to vae.decode. %s", exc, exc_info=True) - return self.vae.decode(latents, return_dict=False)[0] - - def _distributed_tiled_decode(self, z: torch.Tensor, *, vae_patch_parallel_size: int) -> torch.Tensor: - """Distributed version of diffusers AutoencoderKL.tiled_decode (decode only). - - Each rank decodes a subset of tiles; rank0 gathers all tiles and performs the - original blend + stitch logic. Non-rank0 ranks return an empty tensor (their - output is ignored by the scheduler). - """ - group = get_dit_group() - world_size = dist.get_world_size(group) - rank = dist.get_rank(group) - - pp_size = min(int(vae_patch_parallel_size), int(world_size)) - if pp_size <= 1: - return self.vae.decode(z, return_dict=False)[0] - - tile_latent_min_size = getattr(self.vae, "tile_latent_min_size", None) - tile_overlap_factor = getattr(self.vae, "tile_overlap_factor", None) - tile_sample_min_size = getattr(self.vae, "tile_sample_min_size", None) - if tile_latent_min_size is None or tile_overlap_factor is None or tile_sample_min_size is None: - return self.vae.decode(z, return_dict=False)[0] - - overlap_size = int(tile_latent_min_size * (1 - tile_overlap_factor)) - if overlap_size <= 0: - return self.vae.decode(z, return_dict=False)[0] - - h_starts = list(range(0, z.shape[2], overlap_size)) - w_starts = list(range(0, z.shape[3], overlap_size)) - num_rows = len(h_starts) - num_cols = len(w_starts) - num_tiles = num_rows * num_cols - - if num_tiles < 2: - return self.vae.decode(z, return_dict=False)[0] - - blend_extent = int(tile_sample_min_size * tile_overlap_factor) - row_limit = int(tile_sample_min_size - blend_extent) - - # Decide which ranks actively decode tiles. - active = rank < pp_size - - local_tiles: list[torch.Tensor] = [] - local_meta: list[tuple[int, int, int]] = [] - - tile_id = 0 - for i in h_starts: - for j in w_starts: - # Offset assignment by 1 so rank0 avoids decoding the largest (tile_id=0) tile. - tile_rank = (tile_id + 1) % pp_size - if active and (tile_rank == rank): - tile = z[:, :, i : i + tile_latent_min_size, j : j + tile_latent_min_size] - if getattr(self.vae.config, "use_post_quant_conv", False): - tile = self.vae.post_quant_conv(tile) - decoded = self.vae.decoder(tile) - local_tiles.append(decoded) - local_meta.append((tile_id, int(decoded.shape[-2]), int(decoded.shape[-1]))) - tile_id += 1 - - # Gather per-rank tile counts. - count_tensor = torch.tensor([len(local_tiles)], device=z.device, dtype=torch.int64) - if rank == 0: - count_gather = [torch.empty_like(count_tensor) for _ in range(world_size)] - else: - count_gather = None - dist.gather(count_tensor, gather_list=count_gather, dst=0, group=group) - max_count = 0 - if rank == 0: - counts = [int(t.item()) for t in count_gather] # type: ignore[arg-type] - max_count = max(counts) if counts else 0 - max_count_tensor = torch.tensor([max_count], device=z.device, dtype=torch.int64) - dist.broadcast(max_count_tensor, src=0, group=group) - max_count = int(max_count_tensor.item()) - - # Prepare padded metadata + tiles for gather. - meta_tensor = torch.full((max_count, 3), -1, device=z.device, dtype=torch.int64) - tile_tensor = torch.zeros( - (max_count, z.shape[0], 3, tile_sample_min_size, tile_sample_min_size), - device=z.device, - dtype=z.dtype, - ) - for idx, (tile_id, h, w) in enumerate(local_meta): - meta_tensor[idx, 0] = tile_id - meta_tensor[idx, 1] = h - meta_tensor[idx, 2] = w - tile_tensor[idx, :, :, :h, :w] = local_tiles[idx] - - if rank == 0: - meta_gather = [torch.empty_like(meta_tensor) for _ in range(world_size)] - tile_gather = [torch.empty_like(tile_tensor) for _ in range(world_size)] - else: - meta_gather = None - tile_gather = None - - dist.gather(meta_tensor, gather_list=meta_gather, dst=0, group=group) - dist.gather(tile_tensor, gather_list=tile_gather, dst=0, group=group) - - if rank != 0: - return torch.empty(0, device=z.device, dtype=z.dtype) - - # Reconstruct the full tile grid on rank0. - tile_map: dict[int, torch.Tensor] = {} - for src_rank in range(world_size): - meta_src = meta_gather[src_rank] # type: ignore[index] - tiles_src = tile_gather[src_rank] # type: ignore[index] - for idx in range(max_count): - tid = int(meta_src[idx, 0].item()) - if tid < 0: - continue - h = int(meta_src[idx, 1].item()) - w = int(meta_src[idx, 2].item()) - tile_map[tid] = tiles_src[idx, :, :, :h, :w] - - rows: list[list[torch.Tensor]] = [] - for r in range(num_rows): - row: list[torch.Tensor] = [] - for c in range(num_cols): - tid = r * num_cols + c - row.append(tile_map[tid]) - rows.append(row) - - result_rows: list[torch.Tensor] = [] - for i, row in enumerate(rows): - result_row: list[torch.Tensor] = [] - for j, tile in enumerate(row): - if i > 0: - tile = self.vae.blend_v(rows[i - 1][j], tile, blend_extent) - if j > 0: - tile = self.vae.blend_h(row[j - 1], tile, blend_extent) - result_row.append(tile[:, :, :row_limit, :row_limit]) - result_rows.append(torch.cat(result_row, dim=3)) - - return torch.cat(result_rows, dim=2) - - @staticmethod - def _factor_pp_grid(pp_size: int) -> tuple[int, int]: - """Pick a (rows, cols) grid whose product equals `pp_size`.""" - if pp_size <= 1: - return (1, 1) - root = int(math.sqrt(pp_size)) - for rows in range(root, 0, -1): - if pp_size % rows == 0: - return (rows, pp_size // rows) - return (1, pp_size) - - def _distributed_patch_decode(self, z: torch.Tensor, *, vae_patch_parallel_size: int) -> torch.Tensor: - """Decode one spatial block per rank, then stitch RGB blocks on rank0. - - Intended for sizes where diffusers tiling would not kick in, so we can still - reduce the per-rank VAE decode activation peak. Each rank decodes a core - block with a small latent-space halo, then crops to the core and gathers the - RGB blocks to rank0 for final stitching. - """ - group = get_dit_group() - world_size = dist.get_world_size(group) - rank = dist.get_rank(group) - - pp_size = min(int(vae_patch_parallel_size), int(world_size)) - if pp_size <= 1: - return self.vae.decode(z, return_dict=False)[0] - - tile_latent_min_size = getattr(self.vae, "tile_latent_min_size", None) - tile_overlap_factor = getattr(self.vae, "tile_overlap_factor", None) - if tile_latent_min_size is None or tile_overlap_factor is None: - return self.vae.decode(z, return_dict=False)[0] - - overlap_latent = int(tile_latent_min_size * float(tile_overlap_factor)) - halo_base = max(0, overlap_latent // 2) - - # Only ranks < pp_size participate in decoding. Others send empty tensors. - active = rank < pp_size - - bsz, _, latent_h, latent_w = z.shape - scale = int(self.vae_scale_factor) - out_h = latent_h * scale - out_w = latent_w * scale - - local_core = torch.empty(0, device=z.device, dtype=z.dtype) - local_h = 0 - local_w = 0 - - grid_rows, grid_cols = self._factor_pp_grid(pp_size) - - if active: - patch_idx = rank - patch_row = patch_idx // grid_cols - patch_col = patch_idx % grid_cols - - h0 = (patch_row * latent_h) // grid_rows - h1 = ((patch_row + 1) * latent_h) // grid_rows - w0 = (patch_col * latent_w) // grid_cols - w1 = ((patch_col + 1) * latent_w) // grid_cols - - core_h = max(0, h1 - h0) - core_w = max(0, w1 - w0) - if core_h == 0 or core_w == 0: - local_core = torch.empty(0, device=z.device, dtype=z.dtype) - else: - halo = max(halo_base, min(core_h, core_w) // 2) - ph0 = max(0, h0 - halo) - ph1 = min(latent_h, h1 + halo) - pw0 = max(0, w0 - halo) - pw1 = min(latent_w, w1 + halo) - - tile = z[:, :, ph0:ph1, pw0:pw1] - if getattr(self.vae.config, "use_post_quant_conv", False): - tile = self.vae.post_quant_conv(tile) - decoded = self.vae.decoder(tile) - - ch0 = (h0 - ph0) * scale - cw0 = (w0 - pw0) * scale - ch1 = ch0 + core_h * scale - cw1 = cw0 + core_w * scale - local_core = decoded[:, :, ch0:ch1, cw0:cw1] - - local_h = int(local_core.shape[-2]) if local_core.numel() else 0 - local_w = int(local_core.shape[-1]) if local_core.numel() else 0 - - # Gather block shapes. - shape_tensor = torch.tensor([local_h, local_w], device=z.device, dtype=torch.int64) - if rank == 0: - shape_gather = [torch.empty_like(shape_tensor) for _ in range(world_size)] - else: - shape_gather = None - dist.gather(shape_tensor, gather_list=shape_gather, dst=0, group=group) - - max_h = 0 - max_w = 0 - if rank == 0: - shapes = [tuple(int(x.item()) for x in t) for t in shape_gather] # type: ignore[arg-type] - max_h = max((h for h, _ in shapes), default=0) - max_w = max((w for _, w in shapes), default=0) - - max_hw_tensor = torch.tensor([max_h, max_w], device=z.device, dtype=torch.int64) - dist.broadcast(max_hw_tensor, src=0, group=group) - max_h = int(max_hw_tensor[0].item()) - max_w = int(max_hw_tensor[1].item()) - - # Pad local block for gather. - if max_h == 0 or max_w == 0: - padded = torch.empty(0, device=z.device, dtype=z.dtype) - else: - padded = torch.zeros((bsz, 3, max_h, max_w), device=z.device, dtype=z.dtype) - if local_h and local_w: - padded[:, :, :local_h, :local_w] = local_core - - if rank == 0: - block_gather = [torch.empty_like(padded) for _ in range(world_size)] - else: - block_gather = None - dist.gather(padded, gather_list=block_gather, dst=0, group=group) - - if rank != 0: - return torch.empty(0, device=z.device, dtype=z.dtype) - - # Stitch on rank0. - out = torch.empty((bsz, 3, out_h, out_w), device=z.device, dtype=z.dtype) - - for patch_idx in range(pp_size): - src_rank = patch_idx - patch_row = patch_idx // grid_cols - patch_col = patch_idx % grid_cols - - h0 = (patch_row * latent_h) // grid_rows - h1 = ((patch_row + 1) * latent_h) // grid_rows - w0 = (patch_col * latent_w) // grid_cols - w1 = ((patch_col + 1) * latent_w) // grid_cols - - ph = (h1 - h0) * scale - pw = (w1 - w0) * scale - if ph <= 0 or pw <= 0: - continue - - tile = block_gather[src_rank] # type: ignore[index] - out[:, :, h0 * scale : h1 * scale, w0 * scale : w1 * scale] = tile[:, :, :ph, :pw] - - return out - def _encode_prompt( self, prompt: str | list[str], @@ -937,7 +627,7 @@ def forward( torch.cuda.reset_peak_memory_stats(device) torch.cuda.synchronize(device) t0 = time.perf_counter() - image = self._decode_latents_with_vae_patch_parallelism(latents) + image = self.vae.decode(latents, return_dict=False)[0] if device.type == "cuda": torch.cuda.synchronize(device) t1 = time.perf_counter() @@ -961,7 +651,7 @@ def forward( (t1 - t0) * 1000, ) else: - image = self._decode_latents_with_vae_patch_parallelism(latents) + image = self.vae.decode(latents, return_dict=False)[0] # image = self.image_processor.postprocess(image, output_type=output_type) return DiffusionOutput(output=image) diff --git a/vllm_omni/diffusion/registry.py b/vllm_omni/diffusion/registry.py index d4429aff676..5d2765d741c 100644 --- a/vllm_omni/diffusion/registry.py +++ b/vllm_omni/diffusion/registry.py @@ -132,6 +132,15 @@ def initialize_model( model.vae.use_slicing = od_config.vae_use_slicing if hasattr(model.vae, "use_tiling"): model.vae.use_tiling = od_config.vae_use_tiling + if hasattr(model, "vae") and hasattr(od_config, "parallel_config"): + from vllm_omni.diffusion.distributed.parallel_state import get_dit_group + from vllm_omni.diffusion.vae.patch_parallelism import maybe_install_vae_patch_parallelism + + maybe_install_vae_patch_parallelism( + model, + vae_patch_parallel_size=od_config.parallel_config.vae_patch_parallel_size, + group_getter=get_dit_group, + ) # Apply sequence parallelism if enabled # This follows diffusers' pattern where enable_parallelism() is called diff --git a/vllm_omni/diffusion/vae/__init__.py b/vllm_omni/diffusion/vae/__init__.py new file mode 100644 index 00000000000..208f01a7cb5 --- /dev/null +++ b/vllm_omni/diffusion/vae/__init__.py @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project diff --git a/vllm_omni/diffusion/vae/patch_parallelism.py b/vllm_omni/diffusion/vae/patch_parallelism.py new file mode 100644 index 00000000000..1e41c454a66 --- /dev/null +++ b/vllm_omni/diffusion/vae/patch_parallelism.py @@ -0,0 +1,451 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import math +from collections.abc import Callable +from typing import Any + +import torch +import torch.distributed as dist +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def _get_vae_spatial_scale_factor(vae: Any) -> int: + try: + block_out_channels = getattr(getattr(vae, "config", None), "block_out_channels", None) + if block_out_channels: + return 2 ** (len(block_out_channels) - 1) + except Exception: + pass + return 8 + + +def _factor_pp_grid(pp_size: int) -> tuple[int, int]: + """Pick a (rows, cols) grid whose product equals `pp_size`.""" + if pp_size <= 1: + return (1, 1) + root = int(math.sqrt(pp_size)) + for rows in range(root, 0, -1): + if pp_size % rows == 0: + return (rows, pp_size // rows) + return (1, pp_size) + + +def _distributed_tiled_decode( + *, + vae: Any, + orig_decode: Callable[..., Any], + z: torch.Tensor, + group: dist.ProcessGroup, + vae_patch_parallel_size: int, +) -> torch.Tensor: + """Distributed version of diffusers AutoencoderKL.tiled_decode (decode only). + + Each rank decodes a subset of tiles; rank0 gathers all tiles and performs the + original blend + stitch logic. Non-rank0 ranks return an empty tensor; callers + can broadcast the stitched result if needed. + """ + world_size = dist.get_world_size(group) + rank = dist.get_rank(group) + + pp_size = min(int(vae_patch_parallel_size), int(world_size)) + if pp_size <= 1: + return orig_decode(z, return_dict=False)[0] + + tile_latent_min_size = getattr(vae, "tile_latent_min_size", None) + tile_overlap_factor = getattr(vae, "tile_overlap_factor", None) + tile_sample_min_size = getattr(vae, "tile_sample_min_size", None) + if tile_latent_min_size is None or tile_overlap_factor is None or tile_sample_min_size is None: + return orig_decode(z, return_dict=False)[0] + + overlap_size = int(tile_latent_min_size * (1 - tile_overlap_factor)) + if overlap_size <= 0: + return orig_decode(z, return_dict=False)[0] + + h_starts = list(range(0, z.shape[2], overlap_size)) + w_starts = list(range(0, z.shape[3], overlap_size)) + num_rows = len(h_starts) + num_cols = len(w_starts) + num_tiles = num_rows * num_cols + + if num_tiles < 2: + return orig_decode(z, return_dict=False)[0] + + blend_extent = int(tile_sample_min_size * tile_overlap_factor) + row_limit = int(tile_sample_min_size - blend_extent) + + # Decide which ranks actively decode tiles. + active = rank < pp_size + + local_tiles: list[torch.Tensor] = [] + local_meta: list[tuple[int, int, int]] = [] + + tile_id = 0 + for i in h_starts: + for j in w_starts: + # Offset assignment by 1 so rank0 avoids decoding the largest (tile_id=0) tile. + tile_rank = (tile_id + 1) % pp_size + if active and (tile_rank == rank): + tile = z[:, :, i : i + tile_latent_min_size, j : j + tile_latent_min_size] + if getattr(getattr(vae, "config", None), "use_post_quant_conv", False): + tile = vae.post_quant_conv(tile) + decoded = vae.decoder(tile) + local_tiles.append(decoded) + local_meta.append((tile_id, int(decoded.shape[-2]), int(decoded.shape[-1]))) + tile_id += 1 + + # Gather per-rank tile counts. + count_tensor = torch.tensor([len(local_tiles)], device=z.device, dtype=torch.int64) + if rank == 0: + count_gather = [torch.empty_like(count_tensor) for _ in range(world_size)] + else: + count_gather = None + dist.gather(count_tensor, gather_list=count_gather, dst=0, group=group) + max_count = 0 + if rank == 0: + counts = [int(t.item()) for t in count_gather] # type: ignore[arg-type] + max_count = max(counts) if counts else 0 + max_count_tensor = torch.tensor([max_count], device=z.device, dtype=torch.int64) + dist.broadcast(max_count_tensor, src=0, group=group) + max_count = int(max_count_tensor.item()) + + out_channels = int(getattr(getattr(vae, "config", None), "out_channels", 3)) + + # Prepare padded metadata + tiles for gather. + meta_tensor = torch.full((max_count, 3), -1, device=z.device, dtype=torch.int64) + tile_tensor = torch.zeros( + (max_count, z.shape[0], out_channels, tile_sample_min_size, tile_sample_min_size), + device=z.device, + dtype=z.dtype, + ) + for idx, (tile_id, h, w) in enumerate(local_meta): + meta_tensor[idx, 0] = tile_id + meta_tensor[idx, 1] = h + meta_tensor[idx, 2] = w + tile_tensor[idx, :, :, :h, :w] = local_tiles[idx] + + if rank == 0: + meta_gather = [torch.empty_like(meta_tensor) for _ in range(world_size)] + tile_gather = [torch.empty_like(tile_tensor) for _ in range(world_size)] + else: + meta_gather = None + tile_gather = None + + dist.gather(meta_tensor, gather_list=meta_gather, dst=0, group=group) + dist.gather(tile_tensor, gather_list=tile_gather, dst=0, group=group) + + if rank != 0: + return torch.empty(0, device=z.device, dtype=z.dtype) + + # Reconstruct the full tile grid on rank0. + tile_map: dict[int, torch.Tensor] = {} + for src_rank in range(world_size): + meta_src = meta_gather[src_rank] # type: ignore[index] + tiles_src = tile_gather[src_rank] # type: ignore[index] + for idx in range(max_count): + tid = int(meta_src[idx, 0].item()) + if tid < 0: + continue + h = int(meta_src[idx, 1].item()) + w = int(meta_src[idx, 2].item()) + tile_map[tid] = tiles_src[idx, :, :, :h, :w] + + rows: list[list[torch.Tensor]] = [] + for r in range(num_rows): + row: list[torch.Tensor] = [] + for c in range(num_cols): + tid = r * num_cols + c + row.append(tile_map[tid]) + rows.append(row) + + result_rows: list[torch.Tensor] = [] + for i, row in enumerate(rows): + result_row: list[torch.Tensor] = [] + for j, tile in enumerate(row): + if i > 0: + tile = vae.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = vae.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=3)) + + return torch.cat(result_rows, dim=2) + + +def _distributed_patch_decode( + *, + vae: Any, + orig_decode: Callable[..., Any], + z: torch.Tensor, + group: dist.ProcessGroup, + vae_patch_parallel_size: int, + vae_scale_factor: int, +) -> torch.Tensor: + """Decode one spatial block per rank, then stitch RGB blocks on rank0. + + Intended for sizes where diffusers tiling would not kick in, so we can still + reduce the per-rank VAE decode activation peak. Each rank decodes a core + block with a small latent-space halo, then crops to the core and gathers the + RGB blocks to rank0 for final stitching. + """ + world_size = dist.get_world_size(group) + rank = dist.get_rank(group) + + pp_size = min(int(vae_patch_parallel_size), int(world_size)) + if pp_size <= 1: + return orig_decode(z, return_dict=False)[0] + + tile_latent_min_size = getattr(vae, "tile_latent_min_size", None) + tile_overlap_factor = getattr(vae, "tile_overlap_factor", None) + if tile_latent_min_size is None or tile_overlap_factor is None: + return orig_decode(z, return_dict=False)[0] + + overlap_latent = int(tile_latent_min_size * float(tile_overlap_factor)) + halo_base = max(0, overlap_latent // 2) + + # Only ranks < pp_size participate in decoding. Others send empty tensors. + active = rank < pp_size + + bsz, _, latent_h, latent_w = z.shape + scale = int(vae_scale_factor) + out_h = latent_h * scale + out_w = latent_w * scale + + out_channels = int(getattr(getattr(vae, "config", None), "out_channels", 3)) + + local_core = torch.empty(0, device=z.device, dtype=z.dtype) + local_h = 0 + local_w = 0 + + grid_rows, grid_cols = _factor_pp_grid(pp_size) + + if active: + patch_idx = rank + patch_row = patch_idx // grid_cols + patch_col = patch_idx % grid_cols + + h0 = (patch_row * latent_h) // grid_rows + h1 = ((patch_row + 1) * latent_h) // grid_rows + w0 = (patch_col * latent_w) // grid_cols + w1 = ((patch_col + 1) * latent_w) // grid_cols + + core_h = max(0, h1 - h0) + core_w = max(0, w1 - w0) + if core_h == 0 or core_w == 0: + local_core = torch.empty(0, device=z.device, dtype=z.dtype) + else: + halo = max(halo_base, min(core_h, core_w) // 2) + ph0 = max(0, h0 - halo) + ph1 = min(latent_h, h1 + halo) + pw0 = max(0, w0 - halo) + pw1 = min(latent_w, w1 + halo) + + tile = z[:, :, ph0:ph1, pw0:pw1] + if getattr(getattr(vae, "config", None), "use_post_quant_conv", False): + tile = vae.post_quant_conv(tile) + decoded = vae.decoder(tile) + + ch0 = (h0 - ph0) * scale + cw0 = (w0 - pw0) * scale + ch1 = ch0 + core_h * scale + cw1 = cw0 + core_w * scale + local_core = decoded[:, :, ch0:ch1, cw0:cw1] + + local_h = int(local_core.shape[-2]) if local_core.numel() else 0 + local_w = int(local_core.shape[-1]) if local_core.numel() else 0 + + # Gather block shapes. + shape_tensor = torch.tensor([local_h, local_w], device=z.device, dtype=torch.int64) + if rank == 0: + shape_gather = [torch.empty_like(shape_tensor) for _ in range(world_size)] + else: + shape_gather = None + dist.gather(shape_tensor, gather_list=shape_gather, dst=0, group=group) + + max_h = 0 + max_w = 0 + if rank == 0: + shapes = [tuple(int(x.item()) for x in t) for t in shape_gather] # type: ignore[arg-type] + max_h = max((h for h, _ in shapes), default=0) + max_w = max((w for _, w in shapes), default=0) + + max_hw_tensor = torch.tensor([max_h, max_w], device=z.device, dtype=torch.int64) + dist.broadcast(max_hw_tensor, src=0, group=group) + max_h = int(max_hw_tensor[0].item()) + max_w = int(max_hw_tensor[1].item()) + + # Pad local block for gather. + if max_h == 0 or max_w == 0: + padded = torch.empty(0, device=z.device, dtype=z.dtype) + else: + padded = torch.zeros((bsz, out_channels, max_h, max_w), device=z.device, dtype=z.dtype) + if local_h and local_w: + padded[:, :, :local_h, :local_w] = local_core + + if rank == 0: + block_gather = [torch.empty_like(padded) for _ in range(world_size)] + else: + block_gather = None + dist.gather(padded, gather_list=block_gather, dst=0, group=group) + + if rank != 0: + return torch.empty(0, device=z.device, dtype=z.dtype) + + # Stitch on rank0. + out = torch.empty((bsz, out_channels, out_h, out_w), device=z.device, dtype=z.dtype) + + for patch_idx in range(pp_size): + src_rank = patch_idx + patch_row = patch_idx // grid_cols + patch_col = patch_idx % grid_cols + + h0 = (patch_row * latent_h) // grid_rows + h1 = ((patch_row + 1) * latent_h) // grid_rows + w0 = (patch_col * latent_w) // grid_cols + w1 = ((patch_col + 1) * latent_w) // grid_cols + + ph = (h1 - h0) * scale + pw = (w1 - w0) * scale + if ph <= 0 or pw <= 0: + continue + + tile = block_gather[src_rank] # type: ignore[index] + out[:, :, h0 * scale : h1 * scale, w0 * scale : w1 * scale] = tile[:, :, :ph, :pw] + + return out + + +class VaePatchParallelism: + """Patch/tile-parallel VAE decode wrapper. + + This is meant to be installed as an instance-level override of `vae.decode` + so pipelines don't need model-specific code paths. + """ + + def __init__( + self, + vae: Any, + *, + vae_patch_parallel_size: int, + group_getter: Callable[[], dist.ProcessGroup], + ) -> None: + self._vae = vae + self._vae_patch_parallel_size = int(vae_patch_parallel_size) + self._group_getter = group_getter + + self._vae_scale_factor = _get_vae_spatial_scale_factor(vae) + self._orig_decode = vae.decode + + def decode(self, z: torch.Tensor, return_dict: bool = True, *args: Any, **kwargs: Any): + # Keep the original path for unsupported VAE types / shapes. + if z.ndim != 4: + return self._orig_decode(z, return_dict=return_dict, *args, **kwargs) + + if self._vae_patch_parallel_size <= 1 or not dist.is_initialized(): + return self._orig_decode(z, return_dict=return_dict, *args, **kwargs) + + if not getattr(self._vae, "use_tiling", False): + return self._orig_decode(z, return_dict=return_dict, *args, **kwargs) + + try: + group = self._group_getter() + except Exception: + return self._orig_decode(z, return_dict=return_dict, *args, **kwargs) + + world_size = dist.get_world_size(group) + rank = dist.get_rank(group) + pp_size = min(int(self._vae_patch_parallel_size), int(world_size)) + if pp_size <= 1: + return self._orig_decode(z, return_dict=return_dict, *args, **kwargs) + + # Match diffusers' condition for when VAE tiling would be used. + tile_latent_min_size = getattr(self._vae, "tile_latent_min_size", None) + if tile_latent_min_size is None: + decoded = _distributed_tiled_decode( + vae=self._vae, + orig_decode=self._orig_decode, + z=z, + group=group, + vae_patch_parallel_size=pp_size, + ) + else: + should_tile = (z.shape[-1] > tile_latent_min_size) or (z.shape[-2] > tile_latent_min_size) + if should_tile: + decoded = _distributed_tiled_decode( + vae=self._vae, + orig_decode=self._orig_decode, + z=z, + group=group, + vae_patch_parallel_size=pp_size, + ) + else: + decoded = _distributed_patch_decode( + vae=self._vae, + orig_decode=self._orig_decode, + z=z, + group=group, + vae_patch_parallel_size=pp_size, + vae_scale_factor=self._vae_scale_factor, + ) + + if rank == 0 and decoded.numel() == 0: + logger.warning("VAE patch parallel decode produced empty output on rank0; falling back to vae.decode.") + decoded = self._orig_decode(z, return_dict=False, *args, **kwargs)[0] + if rank == 0 and decoded.dtype != z.dtype: + decoded = decoded.to(dtype=z.dtype) + if rank == 0 and not decoded.is_contiguous(): + decoded = decoded.contiguous() + + shape_tensor = torch.empty((4,), device=z.device, dtype=torch.int64) + if rank == 0: + shape_tensor.copy_(torch.tensor(tuple(decoded.shape), device=z.device, dtype=torch.int64)) + dist.broadcast(shape_tensor, src=0, group=group) + + if rank != 0: + decoded = torch.empty(tuple(int(x) for x in shape_tensor.tolist()), device=z.device, dtype=z.dtype) + dist.broadcast(decoded, src=0, group=group) + + if not return_dict: + return (decoded,) + + from diffusers.models.autoencoders.vae import DecoderOutput + + return DecoderOutput(sample=decoded) + + +def maybe_install_vae_patch_parallelism( + pipeline: Any, + *, + vae_patch_parallel_size: int, + group_getter: Callable[[], dist.ProcessGroup], +) -> None: + """Install patch-parallel VAE decode wrapper onto a diffusers-style pipeline.""" + if vae_patch_parallel_size <= 1: + return + + vae = getattr(pipeline, "vae", None) + if vae is None or not hasattr(vae, "decode"): + return + try: + from diffusers.models.autoencoders import AutoencoderKL + except Exception: + return + if not isinstance(vae, AutoencoderKL): + return + + if getattr(vae, "_vllm_vae_patch_parallel_installed", False): + return + + wrapper = VaePatchParallelism( + vae, + vae_patch_parallel_size=vae_patch_parallel_size, + group_getter=group_getter, + ) + + vae._vllm_vae_patch_parallel_installed = True # type: ignore[attr-defined] + vae._vllm_vae_patch_parallel_original_decode = vae.decode # type: ignore[attr-defined] + vae.decode = wrapper.decode # type: ignore[assignment] From 75c8618eea4500a8edc6d63bfd72411a8d49f320 Mon Sep 17 00:00:00 2001 From: dongbo910220 <1275604947@qq.com> Date: Thu, 15 Jan 2026 23:19:58 +0800 Subject: [PATCH 09/39] diffusion: dedupe VAE patch-parallel helpers Signed-off-by: dongbo910220 <1275604947@qq.com> --- vllm_omni/diffusion/vae/patch_parallelism.py | 58 ++++++++++++++------ 1 file changed, 41 insertions(+), 17 deletions(-) diff --git a/vllm_omni/diffusion/vae/patch_parallelism.py b/vllm_omni/diffusion/vae/patch_parallelism.py index 1e41c454a66..961e8363a33 100644 --- a/vllm_omni/diffusion/vae/patch_parallelism.py +++ b/vllm_omni/diffusion/vae/patch_parallelism.py @@ -35,6 +35,37 @@ def _factor_pp_grid(pp_size: int) -> tuple[int, int]: return (1, pp_size) +def _get_world_rank_pp_size( + group: dist.ProcessGroup, + vae_patch_parallel_size: int, +) -> tuple[int, int, int]: + world_size = dist.get_world_size(group) + rank = dist.get_rank(group) + pp_size = min(int(vae_patch_parallel_size), int(world_size)) + return world_size, rank, pp_size + + +def _get_vae_out_channels(vae: Any) -> int: + return int(getattr(getattr(vae, "config", None), "out_channels", 3)) + + +def _get_vae_tile_params(vae: Any) -> tuple[int, float] | None: + tile_latent_min_size = getattr(vae, "tile_latent_min_size", None) + tile_overlap_factor = getattr(vae, "tile_overlap_factor", None) + if tile_latent_min_size is None or tile_overlap_factor is None: + return None + return int(tile_latent_min_size), float(tile_overlap_factor) + + +def _get_vae_tiling_params(vae: Any) -> tuple[int, float, int] | None: + tile_sample_min_size = getattr(vae, "tile_sample_min_size", None) + tile_params = _get_vae_tile_params(vae) + if tile_params is None or tile_sample_min_size is None: + return None + tile_latent_min_size, tile_overlap_factor = tile_params + return tile_latent_min_size, tile_overlap_factor, int(tile_sample_min_size) + + def _distributed_tiled_decode( *, vae: Any, @@ -49,18 +80,14 @@ def _distributed_tiled_decode( original blend + stitch logic. Non-rank0 ranks return an empty tensor; callers can broadcast the stitched result if needed. """ - world_size = dist.get_world_size(group) - rank = dist.get_rank(group) - - pp_size = min(int(vae_patch_parallel_size), int(world_size)) + world_size, rank, pp_size = _get_world_rank_pp_size(group, vae_patch_parallel_size) if pp_size <= 1: return orig_decode(z, return_dict=False)[0] - tile_latent_min_size = getattr(vae, "tile_latent_min_size", None) - tile_overlap_factor = getattr(vae, "tile_overlap_factor", None) - tile_sample_min_size = getattr(vae, "tile_sample_min_size", None) - if tile_latent_min_size is None or tile_overlap_factor is None or tile_sample_min_size is None: + tiling_params = _get_vae_tiling_params(vae) + if tiling_params is None: return orig_decode(z, return_dict=False)[0] + tile_latent_min_size, tile_overlap_factor, tile_sample_min_size = tiling_params overlap_size = int(tile_latent_min_size * (1 - tile_overlap_factor)) if overlap_size <= 0: @@ -113,7 +140,7 @@ def _distributed_tiled_decode( dist.broadcast(max_count_tensor, src=0, group=group) max_count = int(max_count_tensor.item()) - out_channels = int(getattr(getattr(vae, "config", None), "out_channels", 3)) + out_channels = _get_vae_out_channels(vae) # Prepare padded metadata + tiles for gather. meta_tensor = torch.full((max_count, 3), -1, device=z.device, dtype=torch.int64) @@ -192,17 +219,14 @@ def _distributed_patch_decode( block with a small latent-space halo, then crops to the core and gathers the RGB blocks to rank0 for final stitching. """ - world_size = dist.get_world_size(group) - rank = dist.get_rank(group) - - pp_size = min(int(vae_patch_parallel_size), int(world_size)) + world_size, rank, pp_size = _get_world_rank_pp_size(group, vae_patch_parallel_size) if pp_size <= 1: return orig_decode(z, return_dict=False)[0] - tile_latent_min_size = getattr(vae, "tile_latent_min_size", None) - tile_overlap_factor = getattr(vae, "tile_overlap_factor", None) - if tile_latent_min_size is None or tile_overlap_factor is None: + tile_params = _get_vae_tile_params(vae) + if tile_params is None: return orig_decode(z, return_dict=False)[0] + tile_latent_min_size, tile_overlap_factor = tile_params overlap_latent = int(tile_latent_min_size * float(tile_overlap_factor)) halo_base = max(0, overlap_latent // 2) @@ -215,7 +239,7 @@ def _distributed_patch_decode( out_h = latent_h * scale out_w = latent_w * scale - out_channels = int(getattr(getattr(vae, "config", None), "out_channels", 3)) + out_channels = _get_vae_out_channels(vae) local_core = torch.empty(0, device=z.device, dtype=z.dtype) local_h = 0 From a8b0e96d2a54cba0e1b4739fed5ecbafda796b4c Mon Sep 17 00:00:00 2001 From: dongbo910220 <1275604947@qq.com> Date: Thu, 15 Jan 2026 23:29:05 +0800 Subject: [PATCH 10/39] tests: add unit coverage for VAE patch parallelism helpers Signed-off-by: dongbo910220 <1275604947@qq.com> --- .../diffusion/vae/test_patch_parallelism.py | 105 ++++++++++++++++++ 1 file changed, 105 insertions(+) create mode 100644 tests/unit/diffusion/vae/test_patch_parallelism.py diff --git a/tests/unit/diffusion/vae/test_patch_parallelism.py b/tests/unit/diffusion/vae/test_patch_parallelism.py new file mode 100644 index 00000000000..fda27de9dbb --- /dev/null +++ b/tests/unit/diffusion/vae/test_patch_parallelism.py @@ -0,0 +1,105 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest + +from vllm_omni.diffusion.vae import patch_parallelism as pp + + +class _DummyConfig: + def __init__(self, **attrs): + for k, v in attrs.items(): + setattr(self, k, v) + + +class _DummyVae: + def __init__(self, *, config=None, **attrs): + self.config = config + for k, v in attrs.items(): + setattr(self, k, v) + + +def test_get_vae_spatial_scale_factor_uses_block_out_channels_len_minus_1(): + vae = _DummyVae(config=_DummyConfig(block_out_channels=[128, 256, 512, 512])) + assert pp._get_vae_spatial_scale_factor(vae) == 8 + + vae = _DummyVae(config=_DummyConfig(block_out_channels=[1, 2, 3, 4, 5])) + assert pp._get_vae_spatial_scale_factor(vae) == 16 + + +def test_get_vae_spatial_scale_factor_defaults_to_8_on_missing_or_empty(): + assert pp._get_vae_spatial_scale_factor(_DummyVae(config=_DummyConfig())) == 8 + assert pp._get_vae_spatial_scale_factor(_DummyVae(config=_DummyConfig(block_out_channels=[]))) == 8 + assert pp._get_vae_spatial_scale_factor(_DummyVae(config=None)) == 8 + + +def test_get_vae_spatial_scale_factor_defaults_to_8_on_exception(): + class _BrokenConfig: + @property + def block_out_channels(self): + raise RuntimeError("boom") + + assert pp._get_vae_spatial_scale_factor(_DummyVae(config=_BrokenConfig())) == 8 + + +@pytest.mark.parametrize( + ("pp_size", "expected"), + [ + (0, (1, 1)), + (1, (1, 1)), + (2, (1, 2)), + (3, (1, 3)), + (4, (2, 2)), + (6, (2, 3)), + (8, (2, 4)), + (12, (3, 4)), + (16, (4, 4)), + ], +) +def test_factor_pp_grid(pp_size: int, expected: tuple[int, int]): + assert pp._factor_pp_grid(pp_size) == expected + + +def test_get_world_rank_pp_size(monkeypatch): + monkeypatch.setattr(pp.dist, "get_world_size", lambda _: 8) + monkeypatch.setattr(pp.dist, "get_rank", lambda _: 3) + + world_size, rank, pp_size = pp._get_world_rank_pp_size(object(), 4) + assert (world_size, rank, pp_size) == (8, 3, 4) + + world_size, rank, pp_size = pp._get_world_rank_pp_size(object(), 16) + assert (world_size, rank, pp_size) == (8, 3, 8) + + +def test_get_vae_out_channels_defaults_to_3(): + assert pp._get_vae_out_channels(_DummyVae(config=None)) == 3 + assert pp._get_vae_out_channels(_DummyVae(config=_DummyConfig())) == 3 + + +def test_get_vae_out_channels_reads_config(): + assert pp._get_vae_out_channels(_DummyVae(config=_DummyConfig(out_channels=4))) == 4 + assert pp._get_vae_out_channels(_DummyVae(config=_DummyConfig(out_channels="5"))) == 5 + + +def test_get_vae_tile_params_returns_none_if_missing(): + assert pp._get_vae_tile_params(_DummyVae(tile_latent_min_size=None, tile_overlap_factor=0.25)) is None + assert pp._get_vae_tile_params(_DummyVae(tile_latent_min_size=128, tile_overlap_factor=None)) is None + + +def test_get_vae_tile_params_parses_types(): + vae = _DummyVae(tile_latent_min_size="128", tile_overlap_factor="0.25") + assert pp._get_vae_tile_params(vae) == (128, 0.25) + + +def test_get_vae_tiling_params_returns_none_if_missing(): + vae = _DummyVae(tile_latent_min_size=128, tile_overlap_factor=0.25, tile_sample_min_size=None) + assert pp._get_vae_tiling_params(vae) is None + + vae = _DummyVae(tile_latent_min_size=None, tile_overlap_factor=0.25, tile_sample_min_size=1024) + assert pp._get_vae_tiling_params(vae) is None + + +def test_get_vae_tiling_params_parses_types(): + vae = _DummyVae(tile_latent_min_size="128", tile_overlap_factor="0.25", tile_sample_min_size="1024") + assert pp._get_vae_tiling_params(vae) == (128, 0.25, 1024) + From 6dd3e3a80c3216f4199029d496415f25aa831970 Mon Sep 17 00:00:00 2001 From: dongbo910220 <1275604947@qq.com> Date: Thu, 15 Jan 2026 23:56:23 +0800 Subject: [PATCH 11/39] diffusion: inject VAE decode profiling wrapper Signed-off-by: dongbo910220 <1275604947@qq.com> --- .../diffusion/vae/test_patch_parallelism.py | 1 - .../models/z_image/pipeline_z_image.py | 46 +------ vllm_omni/diffusion/registry.py | 6 + vllm_omni/diffusion/vae/decode_profiler.py | 123 ++++++++++++++++++ 4 files changed, 130 insertions(+), 46 deletions(-) create mode 100644 vllm_omni/diffusion/vae/decode_profiler.py diff --git a/tests/unit/diffusion/vae/test_patch_parallelism.py b/tests/unit/diffusion/vae/test_patch_parallelism.py index fda27de9dbb..033885d9f83 100644 --- a/tests/unit/diffusion/vae/test_patch_parallelism.py +++ b/tests/unit/diffusion/vae/test_patch_parallelism.py @@ -102,4 +102,3 @@ def test_get_vae_tiling_params_returns_none_if_missing(): def test_get_vae_tiling_params_parses_types(): vae = _DummyVae(tile_latent_min_size="128", tile_overlap_factor="0.25", tile_sample_min_size="1024") assert pp._get_vae_tiling_params(vae) == (128, 0.25, 1024) - diff --git a/vllm_omni/diffusion/models/z_image/pipeline_z_image.py b/vllm_omni/diffusion/models/z_image/pipeline_z_image.py index 2d5e832c0e2..b92d3de885f 100644 --- a/vllm_omni/diffusion/models/z_image/pipeline_z_image.py +++ b/vllm_omni/diffusion/models/z_image/pipeline_z_image.py @@ -18,12 +18,10 @@ import inspect import json import os -import time from collections.abc import Callable, Iterable from typing import Any import torch -import torch.distributed as dist import torch.nn as nn from diffusers.image_processor import VaeImageProcessor from diffusers.models.autoencoders import AutoencoderKL @@ -31,7 +29,6 @@ from diffusers.utils import logging from diffusers.utils.torch_utils import randn_tensor from transformers import AutoModel, AutoTokenizer -from vllm.logger import init_logger as init_vllm_logger from vllm.model_executor.models.utils import AutoWeightsLoader from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig @@ -46,7 +43,6 @@ ) logger = logging.get_logger(__name__) # pylint: disable=invalid-name -vllm_logger = init_vllm_logger(__name__) def get_post_process_func( @@ -611,47 +607,7 @@ def forward( latents = latents.to(self.vae.dtype) latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor - profile_vae = self.od_config.enable_vae_profiling - if profile_vae: - device = latents.device - dist_rank = None - dist_world_size = None - if dist.is_initialized(): - try: - dist_rank = dist.get_rank() - dist_world_size = dist.get_world_size() - except Exception: - dist_rank = None - dist_world_size = None - if device.type == "cuda": - torch.cuda.reset_peak_memory_stats(device) - torch.cuda.synchronize(device) - t0 = time.perf_counter() - image = self.vae.decode(latents, return_dict=False)[0] - if device.type == "cuda": - torch.cuda.synchronize(device) - t1 = time.perf_counter() - if device.type == "cuda": - peak_alloc_gib = torch.cuda.max_memory_allocated(device) / (1024**3) - peak_resv_gib = torch.cuda.max_memory_reserved(device) / (1024**3) - vllm_logger.debug( - "Z-Image VAE decode profile: rank=%s/%s time_ms=%.3f " - "peak_alloc_gib=%.3f peak_reserved_gib=%.3f", - dist_rank if dist_rank is not None else "na", - dist_world_size if dist_world_size is not None else "na", - (t1 - t0) * 1000, - peak_alloc_gib, - peak_resv_gib, - ) - else: - vllm_logger.debug( - "Z-Image VAE decode profile: rank=%s/%s time_ms=%.3f", - dist_rank if dist_rank is not None else "na", - dist_world_size if dist_world_size is not None else "na", - (t1 - t0) * 1000, - ) - else: - image = self.vae.decode(latents, return_dict=False)[0] + image = self.vae.decode(latents, return_dict=False)[0] # image = self.image_processor.postprocess(image, output_type=output_type) return DiffusionOutput(output=image) diff --git a/vllm_omni/diffusion/registry.py b/vllm_omni/diffusion/registry.py index 5d2765d741c..8c7f6b175d0 100644 --- a/vllm_omni/diffusion/registry.py +++ b/vllm_omni/diffusion/registry.py @@ -134,6 +134,7 @@ def initialize_model( model.vae.use_tiling = od_config.vae_use_tiling if hasattr(model, "vae") and hasattr(od_config, "parallel_config"): from vllm_omni.diffusion.distributed.parallel_state import get_dit_group + from vllm_omni.diffusion.vae.decode_profiler import maybe_install_vae_decode_profiler from vllm_omni.diffusion.vae.patch_parallelism import maybe_install_vae_patch_parallelism maybe_install_vae_patch_parallelism( @@ -141,6 +142,11 @@ def initialize_model( vae_patch_parallel_size=od_config.parallel_config.vae_patch_parallel_size, group_getter=get_dit_group, ) + maybe_install_vae_decode_profiler( + model, + enabled=od_config.enable_vae_profiling, + group_getter=get_dit_group, + ) # Apply sequence parallelism if enabled # This follows diffusers' pattern where enable_parallelism() is called diff --git a/vllm_omni/diffusion/vae/decode_profiler.py b/vllm_omni/diffusion/vae/decode_profiler.py new file mode 100644 index 00000000000..d1088af5ecf --- /dev/null +++ b/vllm_omni/diffusion/vae/decode_profiler.py @@ -0,0 +1,123 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import time +from collections.abc import Callable +from typing import Any + +import torch +import torch.distributed as dist +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def _get_rank_world( + group_getter: Callable[[], dist.ProcessGroup] | None, +) -> tuple[int | None, int | None]: + if not dist.is_initialized(): + return None, None + + group: dist.ProcessGroup | None = None + if group_getter is not None: + try: + group = group_getter() + except Exception: + group = None + + try: + if group is None: + return dist.get_rank(), dist.get_world_size() + return dist.get_rank(group), dist.get_world_size(group) + except Exception: + return None, None + + +class VaeDecodeProfiler: + """Lightweight VAE decode profiler wrapper. + + This is meant to be installed as an instance-level override of `vae.decode` + so pipelines don't need model-specific code paths. + """ + + def __init__( + self, + vae: Any, + *, + label: str, + group_getter: Callable[[], dist.ProcessGroup] | None, + ) -> None: + self._vae = vae + self._label = label + self._group_getter = group_getter + self._orig_decode = vae.decode + + def decode(self, z: torch.Tensor, *args: Any, **kwargs: Any): + device = getattr(z, "device", None) + is_cuda = bool(device is not None and device.type == "cuda") + + if is_cuda: + torch.cuda.reset_peak_memory_stats(device) + torch.cuda.synchronize(device) + t0 = time.perf_counter() + out = self._orig_decode(z, *args, **kwargs) + if is_cuda: + torch.cuda.synchronize(device) + dt_ms = (time.perf_counter() - t0) * 1000 + + dist_rank, dist_world_size = _get_rank_world(self._group_getter) + rank_str = dist_rank if dist_rank is not None else "na" + world_str = dist_world_size if dist_world_size is not None else "na" + + if is_cuda: + peak_alloc_gib = torch.cuda.max_memory_allocated(device) / (1024**3) + peak_reserved_gib = torch.cuda.max_memory_reserved(device) / (1024**3) + logger.debug( + "%s VAE decode profile: rank=%s/%s time_ms=%.3f peak_alloc_gib=%.3f peak_reserved_gib=%.3f", + self._label, + rank_str, + world_str, + dt_ms, + peak_alloc_gib, + peak_reserved_gib, + ) + else: + logger.debug( + "%s VAE decode profile: rank=%s/%s time_ms=%.3f", + self._label, + rank_str, + world_str, + dt_ms, + ) + + return out + + +def maybe_install_vae_decode_profiler( + pipeline: Any, + *, + enabled: bool, + group_getter: Callable[[], dist.ProcessGroup] | None = None, +) -> None: + if not enabled: + return + + vae = getattr(pipeline, "vae", None) + if vae is None or not hasattr(vae, "decode"): + return + + if getattr(vae, "_vllm_vae_decode_profiler_installed", False): + return + + wrapper = VaeDecodeProfiler( + vae, + label=type(pipeline).__name__, + group_getter=group_getter, + ) + + vae._vllm_vae_decode_profiler_installed = True # type: ignore[attr-defined] + vae._vllm_vae_decode_profiler_original_decode = vae.decode # type: ignore[attr-defined] + vae.decode = wrapper.decode # type: ignore[assignment] + From 02dabb006013a278a6c2515f4c95ac39ef93f069 Mon Sep 17 00:00:00 2001 From: dongbo910220 <1275604947@qq.com> Date: Fri, 16 Jan 2026 00:24:20 +0800 Subject: [PATCH 12/39] diffusion: remove VAE decode profiling hooks Signed-off-by: dongbo910220 <1275604947@qq.com> --- vllm_omni/diffusion/data.py | 11 -- vllm_omni/diffusion/registry.py | 6 - vllm_omni/diffusion/vae/decode_profiler.py | 123 --------------------- 3 files changed, 140 deletions(-) delete mode 100644 vllm_omni/diffusion/vae/decode_profiler.py diff --git a/vllm_omni/diffusion/data.py b/vllm_omni/diffusion/data.py index db92873f134..7b89e71a294 100644 --- a/vllm_omni/diffusion/data.py +++ b/vllm_omni/diffusion/data.py @@ -370,10 +370,6 @@ class OmniDiffusionConfig: # support multi images input supports_multimodal_inputs: bool = False - # Logging - enable_vae_profiling: bool = False - """Enable lightweight VAE decode profiling logs.""" - log_level: str = "info" # Omni configuration (injected from stage config) @@ -469,13 +465,6 @@ def __post_init__(self): self.max_cpu_loras = 1 elif self.max_cpu_loras < 1: raise ValueError("max_cpu_loras must be >= 1 for diffusion LoRA") - - if not self.enable_vae_profiling: - # Optional backdoor for quick experimentation without changing code/config. - env_override = os.getenv("VLLM_DIFFUSION_PROFILE_VAE") - if env_override is not None and env_override.strip().lower() not in ("0", "false", "off", "no"): - self.enable_vae_profiling = True - def update_multimodal_support(self) -> None: self.supports_multimodal_inputs = self.model_class_name in {"QwenImageEditPlusPipeline"} diff --git a/vllm_omni/diffusion/registry.py b/vllm_omni/diffusion/registry.py index 8c7f6b175d0..5d2765d741c 100644 --- a/vllm_omni/diffusion/registry.py +++ b/vllm_omni/diffusion/registry.py @@ -134,7 +134,6 @@ def initialize_model( model.vae.use_tiling = od_config.vae_use_tiling if hasattr(model, "vae") and hasattr(od_config, "parallel_config"): from vllm_omni.diffusion.distributed.parallel_state import get_dit_group - from vllm_omni.diffusion.vae.decode_profiler import maybe_install_vae_decode_profiler from vllm_omni.diffusion.vae.patch_parallelism import maybe_install_vae_patch_parallelism maybe_install_vae_patch_parallelism( @@ -142,11 +141,6 @@ def initialize_model( vae_patch_parallel_size=od_config.parallel_config.vae_patch_parallel_size, group_getter=get_dit_group, ) - maybe_install_vae_decode_profiler( - model, - enabled=od_config.enable_vae_profiling, - group_getter=get_dit_group, - ) # Apply sequence parallelism if enabled # This follows diffusers' pattern where enable_parallelism() is called diff --git a/vllm_omni/diffusion/vae/decode_profiler.py b/vllm_omni/diffusion/vae/decode_profiler.py deleted file mode 100644 index d1088af5ecf..00000000000 --- a/vllm_omni/diffusion/vae/decode_profiler.py +++ /dev/null @@ -1,123 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from __future__ import annotations - -import time -from collections.abc import Callable -from typing import Any - -import torch -import torch.distributed as dist -from vllm.logger import init_logger - -logger = init_logger(__name__) - - -def _get_rank_world( - group_getter: Callable[[], dist.ProcessGroup] | None, -) -> tuple[int | None, int | None]: - if not dist.is_initialized(): - return None, None - - group: dist.ProcessGroup | None = None - if group_getter is not None: - try: - group = group_getter() - except Exception: - group = None - - try: - if group is None: - return dist.get_rank(), dist.get_world_size() - return dist.get_rank(group), dist.get_world_size(group) - except Exception: - return None, None - - -class VaeDecodeProfiler: - """Lightweight VAE decode profiler wrapper. - - This is meant to be installed as an instance-level override of `vae.decode` - so pipelines don't need model-specific code paths. - """ - - def __init__( - self, - vae: Any, - *, - label: str, - group_getter: Callable[[], dist.ProcessGroup] | None, - ) -> None: - self._vae = vae - self._label = label - self._group_getter = group_getter - self._orig_decode = vae.decode - - def decode(self, z: torch.Tensor, *args: Any, **kwargs: Any): - device = getattr(z, "device", None) - is_cuda = bool(device is not None and device.type == "cuda") - - if is_cuda: - torch.cuda.reset_peak_memory_stats(device) - torch.cuda.synchronize(device) - t0 = time.perf_counter() - out = self._orig_decode(z, *args, **kwargs) - if is_cuda: - torch.cuda.synchronize(device) - dt_ms = (time.perf_counter() - t0) * 1000 - - dist_rank, dist_world_size = _get_rank_world(self._group_getter) - rank_str = dist_rank if dist_rank is not None else "na" - world_str = dist_world_size if dist_world_size is not None else "na" - - if is_cuda: - peak_alloc_gib = torch.cuda.max_memory_allocated(device) / (1024**3) - peak_reserved_gib = torch.cuda.max_memory_reserved(device) / (1024**3) - logger.debug( - "%s VAE decode profile: rank=%s/%s time_ms=%.3f peak_alloc_gib=%.3f peak_reserved_gib=%.3f", - self._label, - rank_str, - world_str, - dt_ms, - peak_alloc_gib, - peak_reserved_gib, - ) - else: - logger.debug( - "%s VAE decode profile: rank=%s/%s time_ms=%.3f", - self._label, - rank_str, - world_str, - dt_ms, - ) - - return out - - -def maybe_install_vae_decode_profiler( - pipeline: Any, - *, - enabled: bool, - group_getter: Callable[[], dist.ProcessGroup] | None = None, -) -> None: - if not enabled: - return - - vae = getattr(pipeline, "vae", None) - if vae is None or not hasattr(vae, "decode"): - return - - if getattr(vae, "_vllm_vae_decode_profiler_installed", False): - return - - wrapper = VaeDecodeProfiler( - vae, - label=type(pipeline).__name__, - group_getter=group_getter, - ) - - vae._vllm_vae_decode_profiler_installed = True # type: ignore[attr-defined] - vae._vllm_vae_decode_profiler_original_decode = vae.decode # type: ignore[attr-defined] - vae.decode = wrapper.decode # type: ignore[assignment] - From d730c04c9826d6e248ecdcff71738507bd75b789 Mon Sep 17 00:00:00 2001 From: dongbo910220 <1275604947@qq.com> Date: Fri, 16 Jan 2026 00:34:15 +0800 Subject: [PATCH 13/39] diffusion: allowlist VAE patch parallel install Signed-off-by: dongbo910220 <1275604947@qq.com> --- .../offline_inference/test_zimage_tensor_parallel.py | 2 -- vllm_omni/diffusion/registry.py | 11 ++++++++++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/tests/e2e/offline_inference/test_zimage_tensor_parallel.py b/tests/e2e/offline_inference/test_zimage_tensor_parallel.py index 2907588653c..a78be23daca 100644 --- a/tests/e2e/offline_inference/test_zimage_tensor_parallel.py +++ b/tests/e2e/offline_inference/test_zimage_tensor_parallel.py @@ -216,8 +216,6 @@ def test_zimage_vae_patch_parallel_tp2(tmp_path: Path): enforce_eager = _get_enforce_eager_for_cuda() # Use a larger image to ensure there are multiple VAE tiles. - # For Z-Image-Turbo, VAE tiling kicks in when latent_h/latent_w > 128. - # 1152x1152 -> latent 144x144. height = 1152 width = 1152 num_inference_steps = 2 diff --git a/vllm_omni/diffusion/registry.py b/vllm_omni/diffusion/registry.py index 5d2765d741c..f7909dd119f 100644 --- a/vllm_omni/diffusion/registry.py +++ b/vllm_omni/diffusion/registry.py @@ -103,6 +103,11 @@ } ) +_VAE_PATCH_PARALLEL_ALLOWLIST = { + # Only enable for models we have validated end-to-end. + "ZImagePipeline", +} + def initialize_model( od_config: OmniDiffusionConfig, @@ -132,7 +137,11 @@ def initialize_model( model.vae.use_slicing = od_config.vae_use_slicing if hasattr(model.vae, "use_tiling"): model.vae.use_tiling = od_config.vae_use_tiling - if hasattr(model, "vae") and hasattr(od_config, "parallel_config"): + if ( + hasattr(model, "vae") + and hasattr(od_config, "parallel_config") + and od_config.model_class_name in _VAE_PATCH_PARALLEL_ALLOWLIST + ): from vllm_omni.diffusion.distributed.parallel_state import get_dit_group from vllm_omni.diffusion.vae.patch_parallelism import maybe_install_vae_patch_parallelism From b63544ea206f900b35e625b486168e79f7e783c2 Mon Sep 17 00:00:00 2001 From: dongbo910220 <1275604947@qq.com> Date: Fri, 16 Jan 2026 15:48:36 +0800 Subject: [PATCH 14/39] diffusion: drop env override for VAE patch parallel size Signed-off-by: dongbo910220 <1275604947@qq.com> --- vllm_omni/diffusion/data.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/vllm_omni/diffusion/data.py b/vllm_omni/diffusion/data.py index 7b89e71a294..4e9fbfba417 100644 --- a/vllm_omni/diffusion/data.py +++ b/vllm_omni/diffusion/data.py @@ -69,20 +69,6 @@ def __post_init__(self) -> None: if self.sequence_parallel_size is None: self.sequence_parallel_size = self.ulysses_degree * self.ring_degree - env_override = os.environ.get("VLLM_DIFFUSION_VAE_PATCH_PARALLEL_SIZE") - if env_override is not None and self.vae_patch_parallel_size == 1: - try: - env_value = int(env_override) - if env_value > 0: - self.vae_patch_parallel_size = env_value - else: - logger.warning( - "Ignoring invalid VLLM_DIFFUSION_VAE_PATCH_PARALLEL_SIZE=%r (must be > 0).", - env_override, - ) - except ValueError: - logger.warning("Ignoring invalid VLLM_DIFFUSION_VAE_PATCH_PARALLEL_SIZE=%r.", env_override) - self.world_size = ( self.pipeline_parallel_size * self.data_parallel_size From 432276f31166b5e2fda9d6c53b5238c0303c974c Mon Sep 17 00:00:00 2001 From: dongbo910220 <1275604947@qq.com> Date: Fri, 16 Jan 2026 17:25:47 +0800 Subject: [PATCH 15/39] diffusion: remove legacy vae_parallel_size group Signed-off-by: dongbo910220 <1275604947@qq.com> --- .../diffusion/distributed/parallel_state.py | 39 +------------------ 1 file changed, 1 insertion(+), 38 deletions(-) diff --git a/vllm_omni/diffusion/distributed/parallel_state.py b/vllm_omni/diffusion/distributed/parallel_state.py index b249515909d..8363a469523 100644 --- a/vllm_omni/diffusion/distributed/parallel_state.py +++ b/vllm_omni/diffusion/distributed/parallel_state.py @@ -64,7 +64,6 @@ _CFG: GroupCoordinator | None = None _DP: GroupCoordinator | None = None _DIT: GroupCoordinator | None = None -_VAE: GroupCoordinator | None = None def generate_masked_orthogonal_rank_groups( @@ -351,7 +350,7 @@ def is_dp_last_group(): def get_dit_world_size(): - """Return world size for the DiT model (excluding VAE).""" + """Return world size for the DiT model.""" return ( get_data_parallel_world_size() * get_classifier_free_guidance_world_size() @@ -361,22 +360,6 @@ def get_dit_world_size(): ) -# Add VAE getter functions -def get_vae_parallel_group() -> GroupCoordinator: - assert _VAE is not None, "VAE parallel group is not initialized" - return _VAE - - -def get_vae_parallel_world_size(): - """Return world size for the VAE parallel group.""" - return get_vae_parallel_group().world_size - - -def get_vae_parallel_rank(): - """Return my rank for the VAE parallel group.""" - return get_vae_parallel_group().rank_in_group - - # * SET @@ -496,18 +479,6 @@ def get_dit_group(): return _DIT -def init_vae_group( - dit_parallel_size: int, - vae_parallel_size: int, - backend: str, -): - # Initialize VAE group first - global _VAE - assert _VAE is None, "VAE parallel group is already initialized" - vae_ranks = list(range(dit_parallel_size, dit_parallel_size + vae_parallel_size)) - _VAE = torch.distributed.new_group(ranks=vae_ranks, backend=backend) - - # adapted from https://github.com/feifeibear/long-context-attention/blob/main/yunchang/globals.py def set_seq_parallel_pg(sp_ulysses_degree, sp_ring_degree, rank, world_size, use_ulysses_low=True): """ @@ -568,7 +539,6 @@ def initialize_model_parallel( ring_degree: int = 1, tensor_parallel_size: int = 1, pipeline_parallel_size: int = 1, - vae_parallel_size: int = 0, backend: str | None = None, ) -> None: if backend is None: @@ -708,8 +678,6 @@ def initialize_model_parallel( backend=backend, parallel_mode="tensor", ) - if vae_parallel_size > 0: - init_vae_group(dit_parallel_size, vae_parallel_size, backend) init_dit_group(dit_parallel_size, backend) @@ -739,11 +707,6 @@ def destroy_model_parallel(): _PP.destroy() _PP = None - global _VAE - if _VAE: - _VAE.destroy() - _VAE = None - def destroy_distributed_environment(): global _WORLD From 5ee4f9707ad5967e3998d33e6d40924ba2930137 Mon Sep 17 00:00:00 2001 From: dongbo910220 <1275604947@qq.com> Date: Tue, 20 Jan 2026 18:33:50 +0800 Subject: [PATCH 16/39] diffusion: document and reorganize VAE patch parallelism Signed-off-by: dongbo910220 <1275604947@qq.com> --- .../diffusion/parallelism_acceleration.md | 67 +++++++++++++++---- docs/user_guide/diffusion_acceleration.md | 6 ++ ...or_parallel.py => test_zimage_parallel.py} | 2 + .../diffusion/vae/test_patch_parallelism.py | 2 +- .../vae_patch_parallel.py} | 2 + vllm_omni/diffusion/registry.py | 22 ++++-- 6 files changed, 83 insertions(+), 18 deletions(-) rename tests/e2e/offline_inference/{test_zimage_tensor_parallel.py => test_zimage_parallel.py} (99%) rename vllm_omni/diffusion/{vae/patch_parallelism.py => distributed/vae_patch_parallel.py} (99%) diff --git a/docs/user_guide/diffusion/parallelism_acceleration.md b/docs/user_guide/diffusion/parallelism_acceleration.md index 09e4651ae23..3710b4279a3 100644 --- a/docs/user_guide/diffusion/parallelism_acceleration.md +++ b/docs/user_guide/diffusion/parallelism_acceleration.md @@ -14,22 +14,24 @@ The following parallelism methods are currently supported in vLLM-Omni: 4. [Tensor Parallelism](#tensor-parallelism): Tensor parallelism shards model weights across devices. This can reduce per-GPU memory usage. Note that for diffusion models we currently shard the majority of layers within the DiT. +5. [VAE Patch Parallelism](#vae-patch-parallelism): VAE patch parallelism shards VAE decode/encode spatially across ranks. This can reduce the peak memory of VAE decode and (depending on resolution and communication overhead) speed up VAE decode. + The following table shows which models are currently supported by parallelism method: ### ImageGen -| Model | Model Identifier | Ulysses-SP | Ring-SP | CFG-Parallel | Tensor-Parallel | -|-------|------------------|------------|---------|--------------|--------------------------| -| **LongCat-Image** | `meituan-longcat/LongCat-Image` | ✅ | ✅ | ❌ | ❌ | -| **LongCat-Image-Edit** | `meituan-longcat/LongCat-Image-Edit` | ✅ | ✅ | ❌ | ❌ | -| **Ovis-Image** | `OvisAI/Ovis-Image` | ❌ | ❌ | ❌ | ❌ | -| **Qwen-Image** | `Qwen/Qwen-Image` | ✅ | ✅ | ✅ | ✅ | -| **Qwen-Image-Edit** | `Qwen/Qwen-Image-Edit` | ✅ | ✅ | ✅ | ✅ | -| **Qwen-Image-Edit-2509** | `Qwen/Qwen-Image-Edit-2509` | ✅ | ✅ | ✅ | ✅ | -| **Qwen-Image-Layered** | `Qwen/Qwen-Image-Layered` | ✅ | ✅ | ✅ | ✅ | -| **Z-Image** | `Tongyi-MAI/Z-Image-Turbo` | ✅ | ✅ | ❌ | ✅ (TP=2 only) | -| **Stable-Diffusion3.5** | `stabilityai/stable-diffusion-3.5` | ❌ | ❌ | ❌ | ❌ | -| **FLUX.2-klein** | `black-forest-labs/FLUX.2-klein-4B` | ❌ | ❌ | ❌ | ✅ | +| Model | Model Identifier | Ulysses-SP | Ring-Attention | CFG-Parallel | Tensor-Parallel | VAE-Patch-Parallel | +|-------|------------------|------------|----------------|--------------|-----------------|--------------------| +| **LongCat-Image** | `meituan-longcat/LongCat-Image` | ✅ | ✅ | ❌ | ❌ | ❌ | +| **LongCat-Image-Edit** | `meituan-longcat/LongCat-Image-Edit` | ✅ | ✅ | ❌ | ❌ | ❌ | +| **Ovis-Image** | `OvisAI/Ovis-Image` | ❌ | ❌ | ❌ | ❌ | ❌ | +| **Qwen-Image** | `Qwen/Qwen-Image` | ✅ | ✅ | ✅ | ✅ | ❌ | +| **Qwen-Image-Edit** | `Qwen/Qwen-Image-Edit` | ✅ | ✅ | ✅ | ✅ | ❌ | +| **Qwen-Image-Edit-2509** | `Qwen/Qwen-Image-Edit-2509` | ✅ | ✅ | ✅ | ✅ | ❌ | +| **Qwen-Image-Layered** | `Qwen/Qwen-Image-Layered` | ✅ | ✅ | ✅ | ✅ | ❌ | +| **Z-Image** | `Tongyi-MAI/Z-Image-Turbo` | ✅ | ✅ | ❌ | ✅ (TP=2 only) | ✅ | +| **Stable-Diffusion3.5** | `stabilityai/stable-diffusion-3.5` | ❌ | ❌ | ❌ | ❌ | ❌ | +| **FLUX.2-klein** | `black-forest-labs/FLUX.2-klein-4B` | ❌ | ❌ | ❌ | ✅ | ❌ | !!! note "TP Limitations for Diffusion Models" @@ -47,7 +49,7 @@ The following table shows which models are currently supported by parallelism me ### VideoGen -| Model | Model Identifier | Ulysses-SP | Ring-SP | Tensor-Parallel | +| Model | Model Identifier | Ulysses-SP | Ring-Attention | Tensor-Parallel | |-------|------------------|------------|---------|--------------------------| | **Wan2.2** | `Wan-AI/Wan2.2-T2V-A14B-Diffusers` | ✅ | ✅ | ❌ | @@ -74,6 +76,45 @@ outputs = omni.generate( ) ``` +### VAE Patch Parallelism + +VAE patch parallelism distributes the VAE decode/encode workload across multiple ranks by splitting the latent spatially. It is configured via `DiffusionParallelConfig.vae_patch_parallel_size` and can be combined with other parallelism methods (e.g., TP). + +!!! note "Enablement and feature gate" + - VAE patch parallelism is currently **enabled only for validated pipelines** (currently: `Tongyi-MAI/Z-Image-Turbo`). + - Set `vae_use_tiling=True` to enable this feature. (We use `vae_use_tiling` as a safety gate because it indicates the VAE supports diffusers tiling parameters like `tile_latent_min_size` and `tile_overlap_factor`.) + +#### Offline Inference + +```python +from vllm_omni import Omni +from vllm_omni.diffusion.data import DiffusionParallelConfig + +omni = Omni( + model="Tongyi-MAI/Z-Image-Turbo", + parallel_config=DiffusionParallelConfig( + tensor_parallel_size=2, + vae_patch_parallel_size=2, + ), + vae_use_tiling=True, +) + +outputs = omni.generate( + prompt="a cat reading a book", + num_inference_steps=9, + width=1024, + height=1024, +) +``` + +#### How it works (method selection) + +VAE patch parallelism automatically selects between two internal decode methods based on whether diffusers tiling would kick in: + +- `_distributed_tiled_decode`: Used when the latent spatial size exceeds `vae.tile_latent_min_size` (i.e., diffusers tiled decode). Each rank decodes a subset of tiles; rank0 gathers and runs the same overlap+blend+stitch logic as diffusers. This matches the single-rank diffusers tiled output. + +- `_distributed_patch_decode`: Used when diffusers tiling would not kick in. Each rank decodes a grid patch expanded with a latent-space halo; then rank0 gathers the cropped core patches and stitches them into the full image. This path has no blending and can introduce small numerical differences compared to the non-parallel decode. + ### Sequence Parallelism #### Ulysses-SP diff --git a/docs/user_guide/diffusion_acceleration.md b/docs/user_guide/diffusion_acceleration.md index 42202d8d7ed..f4cee02e52f 100644 --- a/docs/user_guide/diffusion_acceleration.md +++ b/docs/user_guide/diffusion_acceleration.md @@ -22,6 +22,10 @@ vLLM-Omni also supports parallelism methods for diffusion models, including: 3. [CFG-Parallel](diffusion/parallelism_acceleration.md#cfg-parallel) - runs the positive/negative prompts of classifier-free guidance (CFG) on different devices, then merges on a single device to perform the scheduler step. +4. [Tensor Parallelism](diffusion/parallelism_acceleration.md#tensor-parallelism) - shards DiT weights across devices to reduce per-GPU memory usage. + +5. [VAE Patch Parallelism](diffusion/parallelism_acceleration.md#vae-patch-parallelism) - shards VAE decode/encode spatially across ranks to reduce VAE peak memory (and can speed up VAE decode). + ## Quick Comparison ### Cache Methods @@ -200,5 +204,7 @@ For detailed information on each acceleration method: - **[TeaCache Guide](diffusion/teacache.md)** - Complete TeaCache documentation, configuration options, and best practices - **[Cache-DiT Acceleration Guide](diffusion/cache_dit_acceleration.md)** - Comprehensive Cache-DiT guide covering DBCache, TaylorSeer, SCM, and configuration parameters +- **[Tensor Parallelism](diffusion/parallelism_acceleration.md#tensor-parallelism)** - Guidance on how to enable TP for diffusion models. - **[Sequence Parallelism](diffusion/parallelism_acceleration.md#sequence-parallelism)** - Guidance on how to set sequence parallelism with configuration. - **[CFG-Parallel](diffusion/parallelism_acceleration.md#cfg-parallel)** - Guidance on how to set CFG-Parallel to run positive/negative branches across ranks. +- **[VAE Patch Parallelism](diffusion/parallelism_acceleration.md#vae-patch-parallelism)** - Guidance on how to reduce VAE memory via patch/tile parallelism. diff --git a/tests/e2e/offline_inference/test_zimage_tensor_parallel.py b/tests/e2e/offline_inference/test_zimage_parallel.py similarity index 99% rename from tests/e2e/offline_inference/test_zimage_tensor_parallel.py rename to tests/e2e/offline_inference/test_zimage_parallel.py index a78be23daca..dc6237ac508 100644 --- a/tests/e2e/offline_inference/test_zimage_tensor_parallel.py +++ b/tests/e2e/offline_inference/test_zimage_parallel.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Z-Image end-to-end tests for diffusion parallelism (TP and VAE patch parallelism).""" + import os import sys import time diff --git a/tests/unit/diffusion/vae/test_patch_parallelism.py b/tests/unit/diffusion/vae/test_patch_parallelism.py index 033885d9f83..b19ac4ab4df 100644 --- a/tests/unit/diffusion/vae/test_patch_parallelism.py +++ b/tests/unit/diffusion/vae/test_patch_parallelism.py @@ -3,7 +3,7 @@ import pytest -from vllm_omni.diffusion.vae import patch_parallelism as pp +from vllm_omni.diffusion.distributed import vae_patch_parallel as pp class _DummyConfig: diff --git a/vllm_omni/diffusion/vae/patch_parallelism.py b/vllm_omni/diffusion/distributed/vae_patch_parallel.py similarity index 99% rename from vllm_omni/diffusion/vae/patch_parallelism.py rename to vllm_omni/diffusion/distributed/vae_patch_parallel.py index 961e8363a33..8e37618f3b4 100644 --- a/vllm_omni/diffusion/vae/patch_parallelism.py +++ b/vllm_omni/diffusion/distributed/vae_patch_parallel.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Distributed VAE patch/tile parallelism utilities.""" + from __future__ import annotations import math diff --git a/vllm_omni/diffusion/registry.py b/vllm_omni/diffusion/registry.py index f7909dd119f..7b24a7a6518 100644 --- a/vllm_omni/diffusion/registry.py +++ b/vllm_omni/diffusion/registry.py @@ -13,6 +13,8 @@ logger = init_logger(__name__) +logger = init_logger(__name__) + _DIFFUSION_MODELS = { # arch:(mod_folder, mod_relname, cls_name) "QwenImagePipeline": ( @@ -137,17 +139,29 @@ def initialize_model( model.vae.use_slicing = od_config.vae_use_slicing if hasattr(model.vae, "use_tiling"): model.vae.use_tiling = od_config.vae_use_tiling + + vae_pp_size = int(getattr(getattr(od_config, "parallel_config", None), "vae_patch_parallel_size", 1)) + if vae_pp_size > 1 and od_config.model_class_name not in _VAE_PATCH_PARALLEL_ALLOWLIST: + logger.warning( + "vae_patch_parallel_size=%d is set but VAE patch parallelism is only enabled for %s; ignoring.", + vae_pp_size, + sorted(_VAE_PATCH_PARALLEL_ALLOWLIST), + ) + if vae_pp_size > 1 and not od_config.vae_use_tiling: + logger.warning("vae_patch_parallel_size=%d requires vae_use_tiling=True; ignoring.", vae_pp_size) + if ( - hasattr(model, "vae") - and hasattr(od_config, "parallel_config") + vae_pp_size > 1 + and hasattr(model, "vae") and od_config.model_class_name in _VAE_PATCH_PARALLEL_ALLOWLIST + and od_config.vae_use_tiling ): from vllm_omni.diffusion.distributed.parallel_state import get_dit_group - from vllm_omni.diffusion.vae.patch_parallelism import maybe_install_vae_patch_parallelism + from vllm_omni.diffusion.distributed.vae_patch_parallel import maybe_install_vae_patch_parallelism maybe_install_vae_patch_parallelism( model, - vae_patch_parallel_size=od_config.parallel_config.vae_patch_parallel_size, + vae_patch_parallel_size=vae_pp_size, group_getter=get_dit_group, ) From 147a69008091a4f7a900047222db06e144ab2279 Mon Sep 17 00:00:00 2001 From: dongbo910220 <1275604947@qq.com> Date: Tue, 20 Jan 2026 18:37:48 +0800 Subject: [PATCH 17/39] docs: add tensor parallelism quickstart Signed-off-by: dongbo910220 <1275604947@qq.com> --- docs/user_guide/diffusion_acceleration.md | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/docs/user_guide/diffusion_acceleration.md b/docs/user_guide/diffusion_acceleration.md index f4cee02e52f..d3bc0ba7c66 100644 --- a/docs/user_guide/diffusion_acceleration.md +++ b/docs/user_guide/diffusion_acceleration.md @@ -178,6 +178,25 @@ omni = Omni( outputs = omni.generate(prompt="A cat sitting on a windowsill", num_inference_steps=50, width=2048, height=2048) ``` +### Using Tensor Parallelism + +```python +from vllm_omni import Omni +from vllm_omni.diffusion.data import DiffusionParallelConfig + +omni = Omni( + model="Tongyi-MAI/Z-Image-Turbo", + parallel_config=DiffusionParallelConfig(tensor_parallel_size=2), +) + +outputs = omni.generate( + prompt="a cat reading a book", + num_inference_steps=9, + width=512, + height=512, +) +``` + ### Using CFG-Parallel Run image-to-image: From a7f5feb2969e831b95b9f0f3ba48067c96ccd617 Mon Sep 17 00:00:00 2001 From: dongbo910220 <1275604947@qq.com> Date: Tue, 20 Jan 2026 22:29:19 +0800 Subject: [PATCH 18/39] diffusion: auto-enable VAE tiling for vae pp Signed-off-by: dongbo910220 <1275604947@qq.com> --- .../diffusion/parallelism_acceleration.md | 2 +- vllm_omni/diffusion/registry.py | 23 +++++++++++++------ 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/docs/user_guide/diffusion/parallelism_acceleration.md b/docs/user_guide/diffusion/parallelism_acceleration.md index 3710b4279a3..53e907fe58c 100644 --- a/docs/user_guide/diffusion/parallelism_acceleration.md +++ b/docs/user_guide/diffusion/parallelism_acceleration.md @@ -82,7 +82,7 @@ VAE patch parallelism distributes the VAE decode/encode workload across multiple !!! note "Enablement and feature gate" - VAE patch parallelism is currently **enabled only for validated pipelines** (currently: `Tongyi-MAI/Z-Image-Turbo`). - - Set `vae_use_tiling=True` to enable this feature. (We use `vae_use_tiling` as a safety gate because it indicates the VAE supports diffusers tiling parameters like `tile_latent_min_size` and `tile_overlap_factor`.) + - If `vae_patch_parallel_size > 1` is set for a validated pipeline, vLLM-Omni will automatically enable `vae_use_tiling` as a safety gate. (We use `vae_use_tiling` because it indicates the VAE supports diffusers tiling parameters like `tile_latent_min_size` and `tile_overlap_factor`.) #### Offline Inference diff --git a/vllm_omni/diffusion/registry.py b/vllm_omni/diffusion/registry.py index 7b24a7a6518..4df7ecfb0be 100644 --- a/vllm_omni/diffusion/registry.py +++ b/vllm_omni/diffusion/registry.py @@ -134,11 +134,6 @@ def initialize_model( model_class = DiffusionModelRegistry._try_load_model_cls(od_config.model_class_name) if model_class is not None: model = model_class(od_config=od_config) - # Configure VAE memory optimization settings from config - if hasattr(model.vae, "use_slicing"): - model.vae.use_slicing = od_config.vae_use_slicing - if hasattr(model.vae, "use_tiling"): - model.vae.use_tiling = od_config.vae_use_tiling vae_pp_size = int(getattr(getattr(od_config, "parallel_config", None), "vae_patch_parallel_size", 1)) if vae_pp_size > 1 and od_config.model_class_name not in _VAE_PATCH_PARALLEL_ALLOWLIST: @@ -147,8 +142,22 @@ def initialize_model( vae_pp_size, sorted(_VAE_PATCH_PARALLEL_ALLOWLIST), ) - if vae_pp_size > 1 and not od_config.vae_use_tiling: - logger.warning("vae_patch_parallel_size=%d requires vae_use_tiling=True; ignoring.", vae_pp_size) + if ( + vae_pp_size > 1 + and od_config.model_class_name in _VAE_PATCH_PARALLEL_ALLOWLIST + and not od_config.vae_use_tiling + ): + logger.info( + "vae_patch_parallel_size=%d requires vae_use_tiling; automatically enabling it.", + vae_pp_size, + ) + od_config.vae_use_tiling = True + + # Configure VAE memory optimization settings from config + if hasattr(model.vae, "use_slicing"): + model.vae.use_slicing = od_config.vae_use_slicing + if hasattr(model.vae, "use_tiling"): + model.vae.use_tiling = od_config.vae_use_tiling if ( vae_pp_size > 1 From de1c63455b702465d73d1e630d87f68eb978a1f4 Mon Sep 17 00:00:00 2001 From: dongbo910220 <1275604947@qq.com> Date: Tue, 20 Jan 2026 23:46:19 +0800 Subject: [PATCH 19/39] docs: expand diffusion acceleration support table Signed-off-by: dongbo910220 <1275604947@qq.com> --- docs/user_guide/diffusion_acceleration.md | 26 +++++++++++------------ 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/docs/user_guide/diffusion_acceleration.md b/docs/user_guide/diffusion_acceleration.md index d3bc0ba7c66..a0123ba1099 100644 --- a/docs/user_guide/diffusion_acceleration.md +++ b/docs/user_guide/diffusion_acceleration.md @@ -41,19 +41,19 @@ The following table shows which models are currently supported by each accelerat ### ImageGen -| Model | Model Identifier | TeaCache | Cache-DiT | Ulysses-SP | Ring-Attention | CFG-Parallel | -|-------|------------------|:----------:|:-----------:|:-----------:|:----------------:|:----------------:| -| **LongCat-Image** | `meituan-longcat/LongCat-Image` | ❌ | ✅ | ❌ | ❌ | ❌ | -| **LongCat-Image-Edit** | `meituan-longcat/LongCat-Image-Edit` | ❌ | ✅ | ❌ | ❌ | ❌ | -| **Ovis-Image** | `OvisAI/Ovis-Image` | ❌ | ✅ | ❌ | ❌ | ❌ | -| **Qwen-Image** | `Qwen/Qwen-Image` | ✅ | ✅ | ✅ | ✅ | ✅ | -| **Qwen-Image-2512** | `Qwen/Qwen-Image-2512` | ✅ | ✅ | ✅ | ✅ | ✅ | -| **Qwen-Image-Edit** | `Qwen/Qwen-Image-Edit` | ✅ | ✅ | ✅ | ✅ | ✅ | -| **Qwen-Image-Edit-2509** | `Qwen/Qwen-Image-Edit-2509` | ✅ | ✅ | ✅ | ✅ | ✅ | -| **Qwen-Image-Layered** | `Qwen/Qwen-Image-Layered` | ❌ | ✅ | ✅ | ✅ | ✅ | -| **Z-Image** | `Tongyi-MAI/Z-Image-Turbo` | ✅ | ✅ | ❌ | ❌ | ❌ | -| **Stable-Diffusion3.5** | `stabilityai/stable-diffusion-3.5` | ❌ | ✅ | ❌ | ❌ | ❌ | -| **Bagel** | `ByteDance-Seed/BAGEL-7B-MoT` | ✅ | ✅ | ❌ | ❌ | ❌ | +| Model | Model Identifier | TeaCache | Cache-DiT | Ulysses-SP | Ring-Attention | CFG-Parallel | Tensor-Parallel | VAE-Patch-Parallel | +|-------|------------------|:----------:|:-----------:|:-----------:|:----------------:|:----------------:|:---------------:|:------------------:| +| **LongCat-Image** | `meituan-longcat/LongCat-Image` | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | +| **LongCat-Image-Edit** | `meituan-longcat/LongCat-Image-Edit` | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | +| **Ovis-Image** | `OvisAI/Ovis-Image` | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | +| **Qwen-Image** | `Qwen/Qwen-Image` | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | +| **Qwen-Image-2512** | `Qwen/Qwen-Image-2512` | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | +| **Qwen-Image-Edit** | `Qwen/Qwen-Image-Edit` | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | +| **Qwen-Image-Edit-2509** | `Qwen/Qwen-Image-Edit-2509` | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | +| **Qwen-Image-Layered** | `Qwen/Qwen-Image-Layered` | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | +| **Z-Image** | `Tongyi-MAI/Z-Image-Turbo` | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ (TP=2 only) | ✅ | +| **Stable-Diffusion3.5** | `stabilityai/stable-diffusion-3.5` | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | +| **Bagel** | `ByteDance-Seed/BAGEL-7B-MoT` | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ### VideoGen From 31bfb4328c4aaade9ecedf9f719f48bd92af3c49 Mon Sep 17 00:00:00 2001 From: dongbo910220 <1275604947@qq.com> Date: Tue, 20 Jan 2026 23:53:32 +0800 Subject: [PATCH 20/39] docs: add VAE patch parallel example and align tables Signed-off-by: dongbo910220 <1275604947@qq.com> --- docs/user_guide/diffusion_acceleration.md | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/docs/user_guide/diffusion_acceleration.md b/docs/user_guide/diffusion_acceleration.md index a0123ba1099..b8194e16471 100644 --- a/docs/user_guide/diffusion_acceleration.md +++ b/docs/user_guide/diffusion_acceleration.md @@ -197,6 +197,28 @@ outputs = omni.generate( ) ``` +### Using VAE Patch Parallelism + +```python +from vllm_omni import Omni +from vllm_omni.diffusion.data import DiffusionParallelConfig + +omni = Omni( + model="Tongyi-MAI/Z-Image-Turbo", + parallel_config=DiffusionParallelConfig( + tensor_parallel_size=2, + vae_patch_parallel_size=2, + ), +) + +outputs = omni.generate( + prompt="a cat reading a book", + num_inference_steps=9, + width=1024, + height=1024, +) +``` + ### Using CFG-Parallel Run image-to-image: From 1536fffe312883219d367c9750c674f9c550f59a Mon Sep 17 00:00:00 2001 From: dongbo910220 <1275604947@qq.com> Date: Wed, 21 Jan 2026 00:17:59 +0800 Subject: [PATCH 21/39] tests: relocate VAE patch parallel unit tests Signed-off-by: dongbo910220 <1275604947@qq.com> --- .../test_vae_patch_parallel.py} | 2 ++ 1 file changed, 2 insertions(+) rename tests/unit/diffusion/{vae/test_patch_parallelism.py => distributed/test_vae_patch_parallel.py} (98%) diff --git a/tests/unit/diffusion/vae/test_patch_parallelism.py b/tests/unit/diffusion/distributed/test_vae_patch_parallel.py similarity index 98% rename from tests/unit/diffusion/vae/test_patch_parallelism.py rename to tests/unit/diffusion/distributed/test_vae_patch_parallel.py index b19ac4ab4df..65b53e409eb 100644 --- a/tests/unit/diffusion/vae/test_patch_parallelism.py +++ b/tests/unit/diffusion/distributed/test_vae_patch_parallel.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Unit tests for VAE patch/tile parallelism helpers (CPU-only).""" + import pytest from vllm_omni.diffusion.distributed import vae_patch_parallel as pp From 9ae119f7a6dc71ac0b2ee937fc505b12cd9fe609 Mon Sep 17 00:00:00 2001 From: dongbo910220 <1275604947@qq.com> Date: Wed, 21 Jan 2026 02:34:25 +0800 Subject: [PATCH 22/39] tests: align Z-Image TP e2e size with upstream Signed-off-by: dongbo910220 <1275604947@qq.com> --- tests/e2e/offline_inference/test_zimage_parallel.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/e2e/offline_inference/test_zimage_parallel.py b/tests/e2e/offline_inference/test_zimage_parallel.py index dc6237ac508..6d2638a6904 100644 --- a/tests/e2e/offline_inference/test_zimage_parallel.py +++ b/tests/e2e/offline_inference/test_zimage_parallel.py @@ -155,8 +155,8 @@ def test_zimage_tensor_parallel_tp2(tmp_path: Path): enforce_eager = _get_enforce_eager_for_cuda() - height = 256 - width = 256 + height = 512 + width = 512 num_inference_steps = 2 seed = 42 From 20e1d3a6a022a7b13bd7e8757e2552e9b4bd97f0 Mon Sep 17 00:00:00 2001 From: dongbo910220 <1275604947@qq.com> Date: Wed, 21 Jan 2026 16:18:27 +0800 Subject: [PATCH 23/39] cleanup: move VAE pp unit test and remove empty vae pkg Signed-off-by: dongbo910220 <1275604947@qq.com> --- .../{unit => }/diffusion/distributed/test_vae_patch_parallel.py | 0 vllm_omni/diffusion/vae/__init__.py | 2 -- 2 files changed, 2 deletions(-) rename tests/{unit => }/diffusion/distributed/test_vae_patch_parallel.py (100%) delete mode 100644 vllm_omni/diffusion/vae/__init__.py diff --git a/tests/unit/diffusion/distributed/test_vae_patch_parallel.py b/tests/diffusion/distributed/test_vae_patch_parallel.py similarity index 100% rename from tests/unit/diffusion/distributed/test_vae_patch_parallel.py rename to tests/diffusion/distributed/test_vae_patch_parallel.py diff --git a/vllm_omni/diffusion/vae/__init__.py b/vllm_omni/diffusion/vae/__init__.py deleted file mode 100644 index 208f01a7cb5..00000000000 --- a/vllm_omni/diffusion/vae/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project From 36869f6f4cfe71fb893cb78ea96d702ff43202a9 Mon Sep 17 00:00:00 2001 From: dongbo910220 <1275604947@qq.com> Date: Wed, 28 Jan 2026 12:46:27 +0800 Subject: [PATCH 24/39] style: apply ruff format Signed-off-by: dongbo910220 <1275604947@qq.com> --- vllm_omni/diffusion/data.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm_omni/diffusion/data.py b/vllm_omni/diffusion/data.py index 4e9fbfba417..7888c170a2a 100644 --- a/vllm_omni/diffusion/data.py +++ b/vllm_omni/diffusion/data.py @@ -451,6 +451,7 @@ def __post_init__(self): self.max_cpu_loras = 1 elif self.max_cpu_loras < 1: raise ValueError("max_cpu_loras must be >= 1 for diffusion LoRA") + def update_multimodal_support(self) -> None: self.supports_multimodal_inputs = self.model_class_name in {"QwenImageEditPlusPipeline"} From 69eeb79c2561c1dbe92932b660fd45d2f72fbd72 Mon Sep 17 00:00:00 2001 From: dongbo910220 <1275604947@qq.com> Date: Thu, 29 Jan 2026 01:32:56 +0800 Subject: [PATCH 25/39] tests: fix zimage platform checks Signed-off-by: dongbo910220 <1275604947@qq.com> --- tests/e2e/offline_inference/test_zimage_parallel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/e2e/offline_inference/test_zimage_parallel.py b/tests/e2e/offline_inference/test_zimage_parallel.py index dd564ed2b18..701560a5c8d 100644 --- a/tests/e2e/offline_inference/test_zimage_parallel.py +++ b/tests/e2e/offline_inference/test_zimage_parallel.py @@ -214,7 +214,7 @@ def test_zimage_tensor_parallel_tp2(tmp_path: Path): @pytest.mark.integration def test_zimage_vae_patch_parallel_tp2(tmp_path: Path): - if is_npu() or is_rocm(): + if current_omni_platform.is_npu() or current_omni_platform.is_rocm(): pytest.skip("Z-Image VAE patch parallel e2e test is only supported on CUDA for now.") if not torch.cuda.is_available() or torch.cuda.device_count() < 2: pytest.skip("Z-Image VAE patch parallel TP=2 requires >= 2 CUDA devices.") From 802da55a85a2cd95e06cb88eeb00e4febf95afa0 Mon Sep 17 00:00:00 2001 From: dongbo910220 <1275604947@qq.com> Date: Wed, 4 Feb 2026 16:15:19 +0800 Subject: [PATCH 26/39] fix: vllm 0.14 multimodal import Signed-off-by: dongbo910220 <1275604947@qq.com> --- vllm_omni/engine/input_processor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_omni/engine/input_processor.py b/vllm_omni/engine/input_processor.py index a1e467a88ac..70858d035a7 100644 --- a/vllm_omni/engine/input_processor.py +++ b/vllm_omni/engine/input_processor.py @@ -11,7 +11,7 @@ from vllm.logger import init_logger from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalUUIDDict -from vllm.multimodal.processing.context import set_request_id +from vllm.multimodal.processing import set_request_id from vllm.multimodal.utils import argsort_mm_positions from vllm.platforms import current_platform from vllm.pooling_params import PoolingParams From 2ee0c12933ff386737b27f6ce4848987fb5d99bd Mon Sep 17 00:00:00 2001 From: dongbo910220 <1275604947@qq.com> Date: Wed, 4 Feb 2026 21:13:13 +0800 Subject: [PATCH 27/39] refactor: clarify VAE patch parallel naming Signed-off-by: dongbo910220 <1275604947@qq.com> --- .../distributed/test_vae_patch_parallel.py | 48 ++++++++++--------- .../distributed/vae_patch_parallel.py | 6 +-- vllm_omni/diffusion/registry.py | 4 +- 3 files changed, 31 insertions(+), 27 deletions(-) diff --git a/tests/diffusion/distributed/test_vae_patch_parallel.py b/tests/diffusion/distributed/test_vae_patch_parallel.py index 65b53e409eb..55c575e8400 100644 --- a/tests/diffusion/distributed/test_vae_patch_parallel.py +++ b/tests/diffusion/distributed/test_vae_patch_parallel.py @@ -5,7 +5,7 @@ import pytest -from vllm_omni.diffusion.distributed import vae_patch_parallel as pp +from vllm_omni.diffusion.distributed import vae_patch_parallel as vae_patch_parallel class _DummyConfig: @@ -23,16 +23,16 @@ def __init__(self, *, config=None, **attrs): def test_get_vae_spatial_scale_factor_uses_block_out_channels_len_minus_1(): vae = _DummyVae(config=_DummyConfig(block_out_channels=[128, 256, 512, 512])) - assert pp._get_vae_spatial_scale_factor(vae) == 8 + assert vae_patch_parallel._get_vae_spatial_scale_factor(vae) == 8 vae = _DummyVae(config=_DummyConfig(block_out_channels=[1, 2, 3, 4, 5])) - assert pp._get_vae_spatial_scale_factor(vae) == 16 + assert vae_patch_parallel._get_vae_spatial_scale_factor(vae) == 16 def test_get_vae_spatial_scale_factor_defaults_to_8_on_missing_or_empty(): - assert pp._get_vae_spatial_scale_factor(_DummyVae(config=_DummyConfig())) == 8 - assert pp._get_vae_spatial_scale_factor(_DummyVae(config=_DummyConfig(block_out_channels=[]))) == 8 - assert pp._get_vae_spatial_scale_factor(_DummyVae(config=None)) == 8 + assert vae_patch_parallel._get_vae_spatial_scale_factor(_DummyVae(config=_DummyConfig())) == 8 + assert vae_patch_parallel._get_vae_spatial_scale_factor(_DummyVae(config=_DummyConfig(block_out_channels=[]))) == 8 + assert vae_patch_parallel._get_vae_spatial_scale_factor(_DummyVae(config=None)) == 8 def test_get_vae_spatial_scale_factor_defaults_to_8_on_exception(): @@ -41,7 +41,7 @@ class _BrokenConfig: def block_out_channels(self): raise RuntimeError("boom") - assert pp._get_vae_spatial_scale_factor(_DummyVae(config=_BrokenConfig())) == 8 + assert vae_patch_parallel._get_vae_spatial_scale_factor(_DummyVae(config=_BrokenConfig())) == 8 @pytest.mark.parametrize( @@ -59,48 +59,52 @@ def block_out_channels(self): ], ) def test_factor_pp_grid(pp_size: int, expected: tuple[int, int]): - assert pp._factor_pp_grid(pp_size) == expected + assert vae_patch_parallel._factor_pp_grid(pp_size) == expected def test_get_world_rank_pp_size(monkeypatch): - monkeypatch.setattr(pp.dist, "get_world_size", lambda _: 8) - monkeypatch.setattr(pp.dist, "get_rank", lambda _: 3) + monkeypatch.setattr(vae_patch_parallel.dist, "get_world_size", lambda _: 8) + monkeypatch.setattr(vae_patch_parallel.dist, "get_rank", lambda _: 3) - world_size, rank, pp_size = pp._get_world_rank_pp_size(object(), 4) + world_size, rank, pp_size = vae_patch_parallel._get_world_rank_pp_size(object(), 4) assert (world_size, rank, pp_size) == (8, 3, 4) - world_size, rank, pp_size = pp._get_world_rank_pp_size(object(), 16) + world_size, rank, pp_size = vae_patch_parallel._get_world_rank_pp_size(object(), 16) assert (world_size, rank, pp_size) == (8, 3, 8) def test_get_vae_out_channels_defaults_to_3(): - assert pp._get_vae_out_channels(_DummyVae(config=None)) == 3 - assert pp._get_vae_out_channels(_DummyVae(config=_DummyConfig())) == 3 + assert vae_patch_parallel._get_vae_out_channels(_DummyVae(config=None)) == 3 + assert vae_patch_parallel._get_vae_out_channels(_DummyVae(config=_DummyConfig())) == 3 def test_get_vae_out_channels_reads_config(): - assert pp._get_vae_out_channels(_DummyVae(config=_DummyConfig(out_channels=4))) == 4 - assert pp._get_vae_out_channels(_DummyVae(config=_DummyConfig(out_channels="5"))) == 5 + assert vae_patch_parallel._get_vae_out_channels(_DummyVae(config=_DummyConfig(out_channels=4))) == 4 + assert vae_patch_parallel._get_vae_out_channels(_DummyVae(config=_DummyConfig(out_channels="5"))) == 5 def test_get_vae_tile_params_returns_none_if_missing(): - assert pp._get_vae_tile_params(_DummyVae(tile_latent_min_size=None, tile_overlap_factor=0.25)) is None - assert pp._get_vae_tile_params(_DummyVae(tile_latent_min_size=128, tile_overlap_factor=None)) is None + assert ( + vae_patch_parallel._get_vae_tile_params(_DummyVae(tile_latent_min_size=None, tile_overlap_factor=0.25)) is None + ) + assert ( + vae_patch_parallel._get_vae_tile_params(_DummyVae(tile_latent_min_size=128, tile_overlap_factor=None)) is None + ) def test_get_vae_tile_params_parses_types(): vae = _DummyVae(tile_latent_min_size="128", tile_overlap_factor="0.25") - assert pp._get_vae_tile_params(vae) == (128, 0.25) + assert vae_patch_parallel._get_vae_tile_params(vae) == (128, 0.25) def test_get_vae_tiling_params_returns_none_if_missing(): vae = _DummyVae(tile_latent_min_size=128, tile_overlap_factor=0.25, tile_sample_min_size=None) - assert pp._get_vae_tiling_params(vae) is None + assert vae_patch_parallel._get_vae_tiling_params(vae) is None vae = _DummyVae(tile_latent_min_size=None, tile_overlap_factor=0.25, tile_sample_min_size=1024) - assert pp._get_vae_tiling_params(vae) is None + assert vae_patch_parallel._get_vae_tiling_params(vae) is None def test_get_vae_tiling_params_parses_types(): vae = _DummyVae(tile_latent_min_size="128", tile_overlap_factor="0.25", tile_sample_min_size="1024") - assert pp._get_vae_tiling_params(vae) == (128, 0.25, 1024) + assert vae_patch_parallel._get_vae_tiling_params(vae) == (128, 0.25, 1024) diff --git a/vllm_omni/diffusion/distributed/vae_patch_parallel.py b/vllm_omni/diffusion/distributed/vae_patch_parallel.py index 8e37618f3b4..1a3edde3ceb 100644 --- a/vllm_omni/diffusion/distributed/vae_patch_parallel.py +++ b/vllm_omni/diffusion/distributed/vae_patch_parallel.py @@ -348,7 +348,7 @@ def _distributed_patch_decode( class VaePatchParallelism: """Patch/tile-parallel VAE decode wrapper. - This is meant to be installed as an instance-level override of `vae.decode` + This is meant to wrap `vae.decode` as an instance-level override so pipelines don't need model-specific code paths. """ @@ -443,13 +443,13 @@ def decode(self, z: torch.Tensor, return_dict: bool = True, *args: Any, **kwargs return DecoderOutput(sample=decoded) -def maybe_install_vae_patch_parallelism( +def maybe_wrap_vae_decode_with_patch_parallelism( pipeline: Any, *, vae_patch_parallel_size: int, group_getter: Callable[[], dist.ProcessGroup], ) -> None: - """Install patch-parallel VAE decode wrapper onto a diffusers-style pipeline.""" + """Wrap a diffusers-style pipeline's `vae.decode` with patch/tile parallel decode.""" if vae_patch_parallel_size <= 1: return diff --git a/vllm_omni/diffusion/registry.py b/vllm_omni/diffusion/registry.py index e7f8ff7fe99..a1822425a4f 100644 --- a/vllm_omni/diffusion/registry.py +++ b/vllm_omni/diffusion/registry.py @@ -171,9 +171,9 @@ def initialize_model( and od_config.vae_use_tiling ): from vllm_omni.diffusion.distributed.parallel_state import get_dit_group - from vllm_omni.diffusion.distributed.vae_patch_parallel import maybe_install_vae_patch_parallelism + from vllm_omni.diffusion.distributed.vae_patch_parallel import maybe_wrap_vae_decode_with_patch_parallelism - maybe_install_vae_patch_parallelism( + maybe_wrap_vae_decode_with_patch_parallelism( model, vae_patch_parallel_size=vae_pp_size, group_getter=get_dit_group, From d92353623f5b01e3b04dfbc4a8b4b0d9f997b4d6 Mon Sep 17 00:00:00 2001 From: dongbo910220 <1275604947@qq.com> Date: Wed, 4 Feb 2026 22:13:41 +0800 Subject: [PATCH 28/39] test: always enforce eager in zimage e2e Signed-off-by: dongbo910220 <1275604947@qq.com> --- tests/e2e/offline_inference/test_zimage_parallel.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/tests/e2e/offline_inference/test_zimage_parallel.py b/tests/e2e/offline_inference/test_zimage_parallel.py index 701560a5c8d..4f0dfd551c3 100644 --- a/tests/e2e/offline_inference/test_zimage_parallel.py +++ b/tests/e2e/offline_inference/test_zimage_parallel.py @@ -70,11 +70,6 @@ def _extract_single_image(outputs) -> Image.Image: return images[0] -def _get_enforce_eager_for_cuda() -> bool: - cc_major, _cc_minor = torch.cuda.get_device_capability(0) - return cc_major < 8 - - def _run_zimage_generate( *, tp_size: int, @@ -157,7 +152,7 @@ def test_zimage_tensor_parallel_tp2(tmp_path: Path): if not torch.cuda.is_available() or torch.cuda.device_count() < 2: pytest.skip("Z-Image TP=2 requires >= 2 CUDA devices.") - enforce_eager = _get_enforce_eager_for_cuda() + enforce_eager = True height = 512 width = 512 @@ -219,7 +214,7 @@ def test_zimage_vae_patch_parallel_tp2(tmp_path: Path): if not torch.cuda.is_available() or torch.cuda.device_count() < 2: pytest.skip("Z-Image VAE patch parallel TP=2 requires >= 2 CUDA devices.") - enforce_eager = _get_enforce_eager_for_cuda() + enforce_eager = True # Use a larger image to ensure there are multiple VAE tiles. height = 1152 From a8f102455a07d21d3ea774527e06140fb0a88946 Mon Sep 17 00:00:00 2001 From: dongbo910220 <1275604947@qq.com> Date: Thu, 5 Feb 2026 14:39:42 +0800 Subject: [PATCH 29/39] refactor: simplify vae patch parallel size access Signed-off-by: dongbo910220 <1275604947@qq.com> --- vllm_omni/diffusion/registry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_omni/diffusion/registry.py b/vllm_omni/diffusion/registry.py index a1822425a4f..0aa7d903c4c 100644 --- a/vllm_omni/diffusion/registry.py +++ b/vllm_omni/diffusion/registry.py @@ -140,7 +140,7 @@ def initialize_model( if model_class is not None: model = model_class(od_config=od_config) - vae_pp_size = int(getattr(getattr(od_config, "parallel_config", None), "vae_patch_parallel_size", 1)) + vae_pp_size = od_config.parallel_config.vae_patch_parallel_size if vae_pp_size > 1 and od_config.model_class_name not in _VAE_PATCH_PARALLEL_ALLOWLIST: logger.warning( "vae_patch_parallel_size=%d is set but VAE patch parallelism is only enabled for %s; ignoring.", From 7f2e3ef5995508a41db45a538a8019f5a1bdad65 Mon Sep 17 00:00:00 2001 From: dongbo910220 <1275604947@qq.com> Date: Thu, 5 Feb 2026 14:56:36 +0800 Subject: [PATCH 30/39] test: rename zimage parallel e2e file Signed-off-by: dongbo910220 <1275604947@qq.com> --- .../{test_zimage_parallel.py => test_zimage_parallelism.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/e2e/offline_inference/{test_zimage_parallel.py => test_zimage_parallelism.py} (100%) diff --git a/tests/e2e/offline_inference/test_zimage_parallel.py b/tests/e2e/offline_inference/test_zimage_parallelism.py similarity index 100% rename from tests/e2e/offline_inference/test_zimage_parallel.py rename to tests/e2e/offline_inference/test_zimage_parallelism.py From b954d2d6f609f32fdd18e2d8cc3acbf3c6649597 Mon Sep 17 00:00:00 2001 From: dongbo910220 <1275604947@qq.com> Date: Thu, 5 Feb 2026 15:41:21 +0800 Subject: [PATCH 31/39] test: clarify zimage parallelism e2e coverage Signed-off-by: dongbo910220 <1275604947@qq.com> --- tests/e2e/offline_inference/test_zimage_parallelism.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/tests/e2e/offline_inference/test_zimage_parallelism.py b/tests/e2e/offline_inference/test_zimage_parallelism.py index 4f0dfd551c3..11a23369b6b 100644 --- a/tests/e2e/offline_inference/test_zimage_parallelism.py +++ b/tests/e2e/offline_inference/test_zimage_parallelism.py @@ -1,7 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Z-Image end-to-end tests for diffusion parallelism (TP and VAE patch parallelism).""" +"""Z-Image end-to-end tests for diffusion parallelism. + +This file currently covers: +- DiT tensor parallelism (TP=2) vs TP=1. +- VAE patch parallelism (vae_patch_parallel_size=2) vs baseline on TP=2. + +Note: CUDA-only (>=2 GPUs). We use `enforce_eager=True` for stability. +""" import os import sys From cdba974902b8e9d0e87c8b4a78bc4a6b325ed4f8 Mon Sep 17 00:00:00 2001 From: dongbo910220 <1275604947@qq.com> Date: Fri, 6 Feb 2026 16:54:16 +0800 Subject: [PATCH 32/39] test: enable compile for zimage e2e Signed-off-by: dongbo910220 <1275604947@qq.com> --- .../test_zimage_parallelism.py | 20 ++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/tests/e2e/offline_inference/test_zimage_parallelism.py b/tests/e2e/offline_inference/test_zimage_parallelism.py index 11a23369b6b..c20306f9f62 100644 --- a/tests/e2e/offline_inference/test_zimage_parallelism.py +++ b/tests/e2e/offline_inference/test_zimage_parallelism.py @@ -7,7 +7,13 @@ - DiT tensor parallelism (TP=2) vs TP=1. - VAE patch parallelism (vae_patch_parallel_size=2) vs baseline on TP=2. -Note: CUDA-only (>=2 GPUs). We use `enforce_eager=True` for stability. +Note: CUDA-only (>=2 GPUs). We use `enforce_eager=False` (default) to enable +`torch.compile` on supported GPUs. On pre-Ampere GPUs (e.g., V100), we force +eager mode because `torch.compile` does not support bfloat16 compilation there. + +For stability, latency is measured after a warmup output (excluding one-time +compilation/caching overhead). Peak memory is measured across the whole +generate run. """ import os @@ -77,6 +83,13 @@ def _extract_single_image(outputs) -> Image.Image: return images[0] +def _should_force_eager_for_compile() -> bool: + # The diffusion pipeline defaults to bfloat16 weights. Torch inductor does + # not support bfloat16 compilation on pre-Ampere GPUs. + major, _minor = torch.cuda.get_device_capability() + return major < 8 + + def _run_zimage_generate( *, tp_size: int, @@ -128,6 +141,7 @@ def _run_zimage_generate( ) warmup_output = next(gen) + t_prev = time.perf_counter() per_request_times_s: list[float] = [] last_output = warmup_output @@ -159,7 +173,7 @@ def test_zimage_tensor_parallel_tp2(tmp_path: Path): if not torch.cuda.is_available() or torch.cuda.device_count() < 2: pytest.skip("Z-Image TP=2 requires >= 2 CUDA devices.") - enforce_eager = True + enforce_eager = _should_force_eager_for_compile() height = 512 width = 512 @@ -221,7 +235,7 @@ def test_zimage_vae_patch_parallel_tp2(tmp_path: Path): if not torch.cuda.is_available() or torch.cuda.device_count() < 2: pytest.skip("Z-Image VAE patch parallel TP=2 requires >= 2 CUDA devices.") - enforce_eager = True + enforce_eager = _should_force_eager_for_compile() # Use a larger image to ensure there are multiple VAE tiles. height = 1152 From d7e55fb6cc4a5b6c8dc9220709a93f457a828e31 Mon Sep 17 00:00:00 2001 From: dongbo910220 <1275604947@qq.com> Date: Fri, 6 Feb 2026 17:51:27 +0800 Subject: [PATCH 33/39] test: trim zimage e2e docstring Signed-off-by: dongbo910220 <1275604947@qq.com> --- tests/e2e/offline_inference/test_zimage_parallelism.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/e2e/offline_inference/test_zimage_parallelism.py b/tests/e2e/offline_inference/test_zimage_parallelism.py index c20306f9f62..ca5ab97f15b 100644 --- a/tests/e2e/offline_inference/test_zimage_parallelism.py +++ b/tests/e2e/offline_inference/test_zimage_parallelism.py @@ -10,10 +10,6 @@ Note: CUDA-only (>=2 GPUs). We use `enforce_eager=False` (default) to enable `torch.compile` on supported GPUs. On pre-Ampere GPUs (e.g., V100), we force eager mode because `torch.compile` does not support bfloat16 compilation there. - -For stability, latency is measured after a warmup output (excluding one-time -compilation/caching overhead). Peak memory is measured across the whole -generate run. """ import os From 5541d132612527c9fd2ec1bd91e9d4d019534c6d Mon Sep 17 00:00:00 2001 From: dongbo910220 <1275604947@qq.com> Date: Fri, 6 Feb 2026 19:35:18 +0800 Subject: [PATCH 34/39] test: drop eager guard in zimage e2e Signed-off-by: dongbo910220 <1275604947@qq.com> --- .../offline_inference/test_zimage_parallelism.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/tests/e2e/offline_inference/test_zimage_parallelism.py b/tests/e2e/offline_inference/test_zimage_parallelism.py index ca5ab97f15b..15b2d054daa 100644 --- a/tests/e2e/offline_inference/test_zimage_parallelism.py +++ b/tests/e2e/offline_inference/test_zimage_parallelism.py @@ -8,8 +8,7 @@ - VAE patch parallelism (vae_patch_parallel_size=2) vs baseline on TP=2. Note: CUDA-only (>=2 GPUs). We use `enforce_eager=False` (default) to enable -`torch.compile` on supported GPUs. On pre-Ampere GPUs (e.g., V100), we force -eager mode because `torch.compile` does not support bfloat16 compilation there. +`torch.compile`. """ import os @@ -79,13 +78,6 @@ def _extract_single_image(outputs) -> Image.Image: return images[0] -def _should_force_eager_for_compile() -> bool: - # The diffusion pipeline defaults to bfloat16 weights. Torch inductor does - # not support bfloat16 compilation on pre-Ampere GPUs. - major, _minor = torch.cuda.get_device_capability() - return major < 8 - - def _run_zimage_generate( *, tp_size: int, @@ -169,7 +161,7 @@ def test_zimage_tensor_parallel_tp2(tmp_path: Path): if not torch.cuda.is_available() or torch.cuda.device_count() < 2: pytest.skip("Z-Image TP=2 requires >= 2 CUDA devices.") - enforce_eager = _should_force_eager_for_compile() + enforce_eager = False height = 512 width = 512 @@ -231,7 +223,7 @@ def test_zimage_vae_patch_parallel_tp2(tmp_path: Path): if not torch.cuda.is_available() or torch.cuda.device_count() < 2: pytest.skip("Z-Image VAE patch parallel TP=2 requires >= 2 CUDA devices.") - enforce_eager = _should_force_eager_for_compile() + enforce_eager = False # Use a larger image to ensure there are multiple VAE tiles. height = 1152 From e0b747a38d0e80be89e9efcb5f05239a501a7018 Mon Sep 17 00:00:00 2001 From: dongbo910220 <1275604947@qq.com> Date: Mon, 9 Feb 2026 09:53:56 +0800 Subject: [PATCH 35/39] fix: remove duplicate logger init Signed-off-by: dongbo910220 <1275604947@qq.com> --- vllm_omni/diffusion/registry.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm_omni/diffusion/registry.py b/vllm_omni/diffusion/registry.py index 0aa7d903c4c..fe9338e37b7 100644 --- a/vllm_omni/diffusion/registry.py +++ b/vllm_omni/diffusion/registry.py @@ -13,8 +13,6 @@ logger = init_logger(__name__) -logger = init_logger(__name__) - _DIFFUSION_MODELS = { # arch:(mod_folder, mod_relname, cls_name) "QwenImagePipeline": ( From 45c4208daab120991092ee3d0387c771bd8a6331 Mon Sep 17 00:00:00 2001 From: dongbo910220 <1275604947@qq.com> Date: Mon, 9 Feb 2026 17:31:37 +0800 Subject: [PATCH 36/39] docs/tests: clarify VAE patch parallelism and update zimage e2e - clarify VAE patch decode wording in the guide - add a tiled decode unit test for VAE patch parallelism - keep Z-Image TP + VAE patch e2e in the same file Signed-off-by: dongbo910220 <1275604947@qq.com> --- .../diffusion/parallelism_acceleration.md | 4 +- .../distributed/test_vae_patch_parallel.py | 112 ++++++++++++++++++ 2 files changed, 114 insertions(+), 2 deletions(-) diff --git a/docs/user_guide/diffusion/parallelism_acceleration.md b/docs/user_guide/diffusion/parallelism_acceleration.md index 11adc7db2b9..052a4f2cecc 100644 --- a/docs/user_guide/diffusion/parallelism_acceleration.md +++ b/docs/user_guide/diffusion/parallelism_acceleration.md @@ -14,7 +14,7 @@ The following parallelism methods are currently supported in vLLM-Omni: 4. [Tensor Parallelism](#tensor-parallelism): Tensor parallelism shards model weights across devices. This can reduce per-GPU memory usage. Note that for diffusion models we currently shard the majority of layers within the DiT. -5. [VAE Patch Parallelism](#vae-patch-parallelism): VAE patch parallelism shards VAE decode/encode spatially across ranks. This can reduce the peak memory of VAE decode and (depending on resolution and communication overhead) speed up VAE decode. +5. [VAE Patch Parallelism](#vae-patch-parallelism): VAE patch parallelism shards VAE decode spatially across ranks. This can reduce the peak memory of VAE decode and (depending on resolution and communication overhead) speed up VAE decode. The following table shows which models are currently supported by parallelism method: @@ -80,7 +80,7 @@ outputs = omni.generate( ### VAE Patch Parallelism -VAE patch parallelism distributes the VAE decode/encode workload across multiple ranks by splitting the latent spatially. It is configured via `DiffusionParallelConfig.vae_patch_parallel_size` and can be combined with other parallelism methods (e.g., TP). +VAE patch parallelism distributes the VAE decode workload across multiple ranks by splitting the latent spatially. It is configured via `DiffusionParallelConfig.vae_patch_parallel_size` and can be combined with other parallelism methods (e.g., TP). !!! note "Enablement and feature gate" - VAE patch parallelism is currently **enabled only for validated pipelines** (currently: `Tongyi-MAI/Z-Image-Turbo`). diff --git a/tests/diffusion/distributed/test_vae_patch_parallel.py b/tests/diffusion/distributed/test_vae_patch_parallel.py index 55c575e8400..43bbc71d5cf 100644 --- a/tests/diffusion/distributed/test_vae_patch_parallel.py +++ b/tests/diffusion/distributed/test_vae_patch_parallel.py @@ -4,6 +4,7 @@ """Unit tests for VAE patch/tile parallelism helpers (CPU-only).""" import pytest +import torch from vllm_omni.diffusion.distributed import vae_patch_parallel as vae_patch_parallel @@ -108,3 +109,114 @@ def test_get_vae_tiling_params_returns_none_if_missing(): def test_get_vae_tiling_params_parses_types(): vae = _DummyVae(tile_latent_min_size="128", tile_overlap_factor="0.25", tile_sample_min_size="1024") assert vae_patch_parallel._get_vae_tiling_params(vae) == (128, 0.25, 1024) + + +def test_distributed_tiled_decode_stitches_tiles(monkeypatch): + class _TinyConfig: + def __init__(self): + self.out_channels = 1 + self.use_post_quant_conv = False + + class _TinyVae: + def __init__(self): + self.config = _TinyConfig() + self.tile_latent_min_size = 2 + self.tile_overlap_factor = 0.0 + self.tile_sample_min_size = 2 + + def decoder(self, x: torch.Tensor) -> torch.Tensor: + return x + + def blend_v(self, _a: torch.Tensor, b: torch.Tensor, _blend_extent: int) -> torch.Tensor: + return b + + def blend_h(self, _a: torch.Tensor, b: torch.Tensor, _blend_extent: int) -> torch.Tensor: + return b + + def _collect_local_tiles( + *, + vae: _TinyVae, + z: torch.Tensor, + rank: int, + pp_size: int, + ) -> tuple[list[torch.Tensor], list[tuple[int, int, int]]]: + tile_latent_min_size = vae.tile_latent_min_size + overlap_size = int(tile_latent_min_size * (1 - vae.tile_overlap_factor)) + h_starts = list(range(0, z.shape[2], overlap_size)) + w_starts = list(range(0, z.shape[3], overlap_size)) + + local_tiles: list[torch.Tensor] = [] + local_meta: list[tuple[int, int, int]] = [] + tile_id = 0 + for i in h_starts: + for j in w_starts: + tile_rank = (tile_id + 1) % pp_size + if tile_rank == rank: + tile = z[:, :, i : i + tile_latent_min_size, j : j + tile_latent_min_size] + decoded = vae.decoder(tile) + local_tiles.append(decoded) + local_meta.append((tile_id, int(decoded.shape[-2]), int(decoded.shape[-1]))) + tile_id += 1 + return local_tiles, local_meta + + vae = _TinyVae() + z = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4) + + rank0_tiles, rank0_meta = _collect_local_tiles(vae=vae, z=z, rank=0, pp_size=2) + rank1_tiles, rank1_meta = _collect_local_tiles(vae=vae, z=z, rank=1, pp_size=2) + max_count = max(len(rank0_tiles), len(rank1_tiles)) + + def _pack_meta_and_tiles( + tiles: list[torch.Tensor], + meta: list[tuple[int, int, int]], + max_count: int, + ) -> tuple[torch.Tensor, torch.Tensor]: + meta_tensor = torch.full((max_count, 3), -1, dtype=torch.int64) + tile_tensor = torch.zeros( + (max_count, z.shape[0], vae.config.out_channels, vae.tile_sample_min_size, vae.tile_sample_min_size), + dtype=z.dtype, + ) + for idx, (tile_id, h, w) in enumerate(meta): + meta_tensor[idx, 0] = tile_id + meta_tensor[idx, 1] = h + meta_tensor[idx, 2] = w + tile_tensor[idx, :, :, :h, :w] = tiles[idx] + return meta_tensor, tile_tensor + + rank1_meta_tensor, rank1_tile_tensor = _pack_meta_and_tiles(rank1_tiles, rank1_meta, max_count) + rank1_count_tensor = torch.tensor([len(rank1_tiles)], dtype=torch.int64) + + def _fake_gather(tensor, gather_list=None, dst=0, group=None): + if gather_list is None: + return + if tensor.ndim == 1 and tensor.numel() == 1: + gather_list[0].copy_(tensor) + gather_list[1].copy_(rank1_count_tensor) + return + if tensor.ndim == 2 and tensor.shape[1] == 3: + gather_list[0].copy_(tensor) + gather_list[1].copy_(rank1_meta_tensor) + return + if tensor.ndim == 5: + gather_list[0].copy_(tensor) + gather_list[1].copy_(rank1_tile_tensor) + return + raise AssertionError("Unexpected gather payload for test.") + + def _fake_broadcast(_tensor, src=0, group=None): + return + + monkeypatch.setattr(vae_patch_parallel.dist, "get_world_size", lambda _group: 2) + monkeypatch.setattr(vae_patch_parallel.dist, "get_rank", lambda _group: 0) + monkeypatch.setattr(vae_patch_parallel.dist, "gather", _fake_gather) + monkeypatch.setattr(vae_patch_parallel.dist, "broadcast", _fake_broadcast) + + output = vae_patch_parallel._distributed_tiled_decode( + vae=vae, + orig_decode=lambda z, return_dict=False: (z,), + z=z, + group=object(), + vae_patch_parallel_size=2, + ) + + assert torch.equal(output, z) From f4e1db714792f648b04de485427495eb54af8cbc Mon Sep 17 00:00:00 2001 From: dongbo910220 <1275604947@qq.com> Date: Tue, 10 Feb 2026 16:11:22 +0800 Subject: [PATCH 37/39] ci: fix zimage TP test path and guard set_request_id import Signed-off-by: dongbo910220 <1275604947@qq.com> --- .buildkite/pipeline.yml | 2 +- .buildkite/test-amd.yaml | 2 +- docs/contributing/features/tensor_parallel.md | 2 +- vllm_omni/engine/input_processor.py | 9 ++++++++- 4 files changed, 11 insertions(+), 4 deletions(-) diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 01897db969d..525869f4270 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -141,7 +141,7 @@ steps: timeout_in_minutes: 20 depends_on: image-build commands: - - pytest -s -v tests/e2e/offline_inference/test_zimage_tensor_parallel.py + - pytest -s -v tests/e2e/offline_inference/test_zimage_parallelism.py -k tensor_parallel agents: queue: "gpu_4_queue" # g6.12xlarge instance on AWS, has 4 L4 GPU plugins: diff --git a/.buildkite/test-amd.yaml b/.buildkite/test-amd.yaml index 773748c1737..6aada531945 100644 --- a/.buildkite/test-amd.yaml +++ b/.buildkite/test-amd.yaml @@ -66,7 +66,7 @@ steps: - export GPU_ARCHS=gfx942 - export VLLM_LOGGING_LEVEL=DEBUG - export VLLM_WORKER_MULTIPROC_METHOD=spawn - - pytest -s -v tests/e2e/offline_inference/test_zimage_tensor_parallel.py + - pytest -s -v tests/e2e/offline_inference/test_zimage_parallelism.py -k tensor_parallel - label: "Diffusion GPU Worker Test" timeout_in_minutes: 20 diff --git a/docs/contributing/features/tensor_parallel.md b/docs/contributing/features/tensor_parallel.md index aecc1aedca1..6214dc78303 100644 --- a/docs/contributing/features/tensor_parallel.md +++ b/docs/contributing/features/tensor_parallel.md @@ -264,7 +264,7 @@ Complete examples in the codebase: | **Z-Image** | `vllm_omni/diffusion/models/z_image/z_image_transformer.py` | Standard TP | Full implementation with validation | | **FLUX** | `vllm_omni/diffusion/models/flux/flux_transformer.py` | Dual-stream | Image + text streams | | **Qwen-Image** | `vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py` | Standard TP | With RoPE | -| **TP Tests** | `tests/e2e/offline_inference/test_zimage_tensor_parallel.py` | E2E testing | TP correctness and performance | +| **TP Tests** | `tests/e2e/offline_inference/test_zimage_parallelism.py` | E2E testing | TP correctness and performance | | **Constraint Tests** | `tests/diffusion/models/z_image/test_zimage_tp_constraints.py` | Unit testing | Validation logic | --- diff --git a/vllm_omni/engine/input_processor.py b/vllm_omni/engine/input_processor.py index 70858d035a7..190769c8602 100644 --- a/vllm_omni/engine/input_processor.py +++ b/vllm_omni/engine/input_processor.py @@ -11,7 +11,14 @@ from vllm.logger import init_logger from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalUUIDDict -from vllm.multimodal.processing import set_request_id +try: + from vllm.multimodal.processing import set_request_id +except ImportError: # vllm without set_request_id (older releases) + from contextlib import contextmanager + + @contextmanager + def set_request_id(_request_id: str): + yield from vllm.multimodal.utils import argsort_mm_positions from vllm.platforms import current_platform from vllm.pooling_params import PoolingParams From 08e086b43d95994882e8a4544145b564b0b70bca Mon Sep 17 00:00:00 2001 From: dongbo910220 <1275604947@qq.com> Date: Tue, 10 Feb 2026 16:23:12 +0800 Subject: [PATCH 38/39] ci: run full zimage parallelism test Signed-off-by: dongbo910220 <1275604947@qq.com> --- .buildkite/pipeline.yml | 2 +- .buildkite/test-amd.yaml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 525869f4270..5e523efb227 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -141,7 +141,7 @@ steps: timeout_in_minutes: 20 depends_on: image-build commands: - - pytest -s -v tests/e2e/offline_inference/test_zimage_parallelism.py -k tensor_parallel + - pytest -s -v tests/e2e/offline_inference/test_zimage_parallelism.py agents: queue: "gpu_4_queue" # g6.12xlarge instance on AWS, has 4 L4 GPU plugins: diff --git a/.buildkite/test-amd.yaml b/.buildkite/test-amd.yaml index 6aada531945..5b4cedd3fd9 100644 --- a/.buildkite/test-amd.yaml +++ b/.buildkite/test-amd.yaml @@ -66,7 +66,7 @@ steps: - export GPU_ARCHS=gfx942 - export VLLM_LOGGING_LEVEL=DEBUG - export VLLM_WORKER_MULTIPROC_METHOD=spawn - - pytest -s -v tests/e2e/offline_inference/test_zimage_parallelism.py -k tensor_parallel + - pytest -s -v tests/e2e/offline_inference/test_zimage_parallelism.py - label: "Diffusion GPU Worker Test" timeout_in_minutes: 20 From 6a95bfbafa5e9a83ee814318d632e5e1b2ccb5d9 Mon Sep 17 00:00:00 2001 From: dongbo910220 <1275604947@qq.com> Date: Tue, 10 Feb 2026 16:26:30 +0800 Subject: [PATCH 39/39] style: apply pre-commit formatting Signed-off-by: dongbo910220 <1275604947@qq.com> --- vllm_omni/engine/input_processor.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm_omni/engine/input_processor.py b/vllm_omni/engine/input_processor.py index 190769c8602..64ffab3a1ab 100644 --- a/vllm_omni/engine/input_processor.py +++ b/vllm_omni/engine/input_processor.py @@ -11,6 +11,7 @@ from vllm.logger import init_logger from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalUUIDDict + try: from vllm.multimodal.processing import set_request_id except ImportError: # vllm without set_request_id (older releases) @@ -19,6 +20,8 @@ @contextmanager def set_request_id(_request_id: str): yield + + from vllm.multimodal.utils import argsort_mm_positions from vllm.platforms import current_platform from vllm.pooling_params import PoolingParams