diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 01897db969d..5e523efb227 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -141,7 +141,7 @@ steps: timeout_in_minutes: 20 depends_on: image-build commands: - - pytest -s -v tests/e2e/offline_inference/test_zimage_tensor_parallel.py + - pytest -s -v tests/e2e/offline_inference/test_zimage_parallelism.py agents: queue: "gpu_4_queue" # g6.12xlarge instance on AWS, has 4 L4 GPU plugins: diff --git a/.buildkite/test-amd.yaml b/.buildkite/test-amd.yaml index 773748c1737..5b4cedd3fd9 100644 --- a/.buildkite/test-amd.yaml +++ b/.buildkite/test-amd.yaml @@ -66,7 +66,7 @@ steps: - export GPU_ARCHS=gfx942 - export VLLM_LOGGING_LEVEL=DEBUG - export VLLM_WORKER_MULTIPROC_METHOD=spawn - - pytest -s -v tests/e2e/offline_inference/test_zimage_tensor_parallel.py + - pytest -s -v tests/e2e/offline_inference/test_zimage_parallelism.py - label: "Diffusion GPU Worker Test" timeout_in_minutes: 20 diff --git a/docs/contributing/features/tensor_parallel.md b/docs/contributing/features/tensor_parallel.md index aecc1aedca1..6214dc78303 100644 --- a/docs/contributing/features/tensor_parallel.md +++ b/docs/contributing/features/tensor_parallel.md @@ -264,7 +264,7 @@ Complete examples in the codebase: | **Z-Image** | `vllm_omni/diffusion/models/z_image/z_image_transformer.py` | Standard TP | Full implementation with validation | | **FLUX** | `vllm_omni/diffusion/models/flux/flux_transformer.py` | Dual-stream | Image + text streams | | **Qwen-Image** | `vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py` | Standard TP | With RoPE | -| **TP Tests** | `tests/e2e/offline_inference/test_zimage_tensor_parallel.py` | E2E testing | TP correctness and performance | +| **TP Tests** | `tests/e2e/offline_inference/test_zimage_parallelism.py` | E2E testing | TP correctness and performance | | **Constraint Tests** | `tests/diffusion/models/z_image/test_zimage_tp_constraints.py` | Unit testing | Validation logic | --- diff --git a/docs/user_guide/diffusion/parallelism_acceleration.md b/docs/user_guide/diffusion/parallelism_acceleration.md index e2dae5c84a4..7b6e9390e48 100644 --- a/docs/user_guide/diffusion/parallelism_acceleration.md +++ b/docs/user_guide/diffusion/parallelism_acceleration.md @@ -14,23 +14,25 @@ The following parallelism methods are currently supported in vLLM-Omni: 4. [Tensor Parallelism](#tensor-parallelism): Tensor parallelism shards model weights across devices. This can reduce per-GPU memory usage. Note that for diffusion models we currently shard the majority of layers within the DiT. +5. [VAE Patch Parallelism](#vae-patch-parallelism): VAE patch parallelism shards VAE decode spatially across ranks. This can reduce the peak memory of VAE decode and (depending on resolution and communication overhead) speed up VAE decode. + The following table shows which models are currently supported by parallelism method: ### ImageGen -| Model | Model Identifier | Ulysses-SP | Ring-SP | CFG-Parallel | Tensor-Parallel | -|--------------------------|--------------------------------------|:----------:|:-------:|:------------:|:---------------:| -| **LongCat-Image** | `meituan-longcat/LongCat-Image` | ✅ | ✅ | ❌ | ✅ | -| **LongCat-Image-Edit** | `meituan-longcat/LongCat-Image-Edit` | ✅ | ✅ | ❌ | ✅ | -| **Ovis-Image** | `OvisAI/Ovis-Image` | ❌ | ❌ | ❌ | ❌ | -| **Qwen-Image** | `Qwen/Qwen-Image` | ✅ | ✅ | ✅ | ✅ | -| **Qwen-Image-Edit** | `Qwen/Qwen-Image-Edit` | ✅ | ✅ | ✅ | ✅ | -| **Qwen-Image-Edit-2509** | `Qwen/Qwen-Image-Edit-2509` | ✅ | ✅ | ✅ | ✅ | -| **Qwen-Image-Layered** | `Qwen/Qwen-Image-Layered` | ✅ | ✅ | ✅ | ✅ | -| **Z-Image** | `Tongyi-MAI/Z-Image-Turbo` | ✅ | ✅ | ❌ | ✅ (TP=2 only) | -| **Stable-Diffusion3.5** | `stabilityai/stable-diffusion-3.5` | ❌ | ❌ | ❌ | ❌ | -| **FLUX.2-klein** | `black-forest-labs/FLUX.2-klein-4B` | ❌ | ❌ | ❌ | ✅ | -| **FLUX.1-dev** | `black-forest-labs/FLUX.1-dev` | ❌ | ❌ | ✅ | ✅ | +| Model | Model Identifier | Ulysses-SP | Ring-SP | CFG-Parallel | Tensor-Parallel | VAE-Patch-Parallel | +|--------------------------|--------------------------------------|:----------:|:-------:|:------------:|:---------------:|:------------------:| +| **LongCat-Image** | `meituan-longcat/LongCat-Image` | ✅ | ✅ | ❌ | ✅ | ❌ | +| **LongCat-Image-Edit** | `meituan-longcat/LongCat-Image-Edit` | ✅ | ✅ | ❌ | ✅ | ❌ | +| **Ovis-Image** | `OvisAI/Ovis-Image` | ❌ | ❌ | ❌ | ❌ | ❌ | +| **Qwen-Image** | `Qwen/Qwen-Image` | ✅ | ✅ | ✅ | ✅ | ❌ | +| **Qwen-Image-Edit** | `Qwen/Qwen-Image-Edit` | ✅ | ✅ | ✅ | ✅ | ❌ | +| **Qwen-Image-Edit-2509** | `Qwen/Qwen-Image-Edit-2509` | ✅ | ✅ | ✅ | ✅ | ❌ | +| **Qwen-Image-Layered** | `Qwen/Qwen-Image-Layered` | ✅ | ✅ | ✅ | ✅ | ❌ | +| **Z-Image** | `Tongyi-MAI/Z-Image-Turbo` | ✅ | ✅ | ❌ | ✅ (TP=2 only) | ✅ | +| **Stable-Diffusion3.5** | `stabilityai/stable-diffusion-3.5` | ❌ | ❌ | ❌ | ❌ | ❌ | +| **FLUX.2-klein** | `black-forest-labs/FLUX.2-klein-4B` | ❌ | ❌ | ❌ | ✅ | ❌ | +| **FLUX.1-dev** | `black-forest-labs/FLUX.1-dev` | ❌ | ❌ | ✅ | ✅ | ❌ | !!! note "TP Limitations for Diffusion Models" We currently implement Tensor Parallelism (TP) only for the DiT (Diffusion Transformer) blocks. This is because the `text_encoder` component in vLLM-Omni uses the original Transformers implementation, which does not yet support TP. @@ -47,7 +49,7 @@ The following table shows which models are currently supported by parallelism me ### VideoGen -| Model | Model Identifier | Ulysses-SP | Ring-SP | Tensor-Parallel | +| Model | Model Identifier | Ulysses-SP | Ring-Attention | Tensor-Parallel | |-------|------------------|------------|---------|--------------------------| | **Wan2.2** | `Wan-AI/Wan2.2-T2V-A14B-Diffusers` | ✅ | ✅ | ✅ | @@ -76,6 +78,45 @@ outputs = omni.generate( ) ``` +### VAE Patch Parallelism + +VAE patch parallelism distributes the VAE decode workload across multiple ranks by splitting the latent spatially. It is configured via `DiffusionParallelConfig.vae_patch_parallel_size` and can be combined with other parallelism methods (e.g., TP). + +!!! note "Enablement and feature gate" + - VAE patch parallelism is currently **enabled only for validated pipelines** (currently: `Tongyi-MAI/Z-Image-Turbo`). + - If `vae_patch_parallel_size > 1` is set for a validated pipeline, vLLM-Omni will automatically enable `vae_use_tiling` as a safety gate. (We use `vae_use_tiling` because it indicates the VAE supports diffusers tiling parameters like `tile_latent_min_size` and `tile_overlap_factor`.) + +#### Offline Inference + +```python +from vllm_omni import Omni +from vllm_omni.diffusion.data import DiffusionParallelConfig + +omni = Omni( + model="Tongyi-MAI/Z-Image-Turbo", + parallel_config=DiffusionParallelConfig( + tensor_parallel_size=2, + vae_patch_parallel_size=2, + ), + vae_use_tiling=True, +) + +outputs = omni.generate( + prompt="a cat reading a book", + num_inference_steps=9, + width=1024, + height=1024, +) +``` + +#### How it works (method selection) + +VAE patch parallelism automatically selects between two internal decode methods based on whether diffusers tiling would kick in: + +- `_distributed_tiled_decode`: Used when the latent spatial size exceeds `vae.tile_latent_min_size` (i.e., diffusers tiled decode). Each rank decodes a subset of tiles; rank0 gathers and runs the same overlap+blend+stitch logic as diffusers. This matches the single-rank diffusers tiled output. + +- `_distributed_patch_decode`: Used when diffusers tiling would not kick in. Each rank decodes a grid patch expanded with a latent-space halo; then rank0 gathers the cropped core patches and stitches them into the full image. This path has no blending and can introduce small numerical differences compared to the non-parallel decode. + ### Sequence Parallelism #### Ulysses-SP diff --git a/docs/user_guide/diffusion_acceleration.md b/docs/user_guide/diffusion_acceleration.md index ba122a1cb7a..7856ea9606f 100644 --- a/docs/user_guide/diffusion_acceleration.md +++ b/docs/user_guide/diffusion_acceleration.md @@ -22,6 +22,10 @@ vLLM-Omni also supports parallelism methods for diffusion models, including: 3. [CFG-Parallel](diffusion/parallelism_acceleration.md#cfg-parallel) - runs the positive/negative prompts of classifier-free guidance (CFG) on different devices, then merges on a single device to perform the scheduler step. +4. [Tensor Parallelism](diffusion/parallelism_acceleration.md#tensor-parallelism) - shards DiT weights across devices to reduce per-GPU memory usage. + +5. [VAE Patch Parallelism](diffusion/parallelism_acceleration.md#vae-patch-parallelism) - shards VAE decode/encode spatially across ranks to reduce VAE peak memory (and can speed up VAE decode). + ## Quick Comparison ### Cache Methods @@ -37,20 +41,20 @@ The following table shows which models are currently supported by each accelerat ### ImageGen -| Model | Model Identifier | TeaCache | Cache-DiT | Ulysses-SP | Ring-Attention | CFG-Parallel | -|-------|------------------|:----------:|:-----------:|:-----------:|:----------------:|:----------------:| -| **LongCat-Image** | `meituan-longcat/LongCat-Image` | ❌ | ✅ | ❌ | ❌ | ✅ | -| **LongCat-Image-Edit** | `meituan-longcat/LongCat-Image-Edit` | ❌ | ✅ | ❌ | ❌ | ✅ | -| **Ovis-Image** | `OvisAI/Ovis-Image` | ❌ | ✅ | ❌ | ❌ | ✅ | -| **Qwen-Image** | `Qwen/Qwen-Image` | ✅ | ✅ | ✅ | ✅ | ✅ | -| **Qwen-Image-2512** | `Qwen/Qwen-Image-2512` | ✅ | ✅ | ✅ | ✅ | ✅ | -| **Qwen-Image-Edit** | `Qwen/Qwen-Image-Edit` | ✅ | ✅ | ✅ | ✅ | ✅ | -| **Qwen-Image-Edit-2509** | `Qwen/Qwen-Image-Edit-2509` | ✅ | ✅ | ✅ | ✅ | ✅ | -| **Qwen-Image-Layered** | `Qwen/Qwen-Image-Layered` | ❌ | ✅ | ✅ | ✅ | ✅ | -| **Z-Image** | `Tongyi-MAI/Z-Image-Turbo` | ✅ | ✅ | ❌ | ❌ | ❌ | -| **Stable-Diffusion3.5** | `stabilityai/stable-diffusion-3.5` | ❌ | ✅ | ❌ | ❌ | ✅ | -| **Bagel** | `ByteDance-Seed/BAGEL-7B-MoT` | ✅ | ✅ | ❌ | ❌ | ❌ | -| **FLUX.1-dev** | `black-forest-labs/FLUX.1-dev` | ❌ | ✅ | ❌ | ❌ | ✅ | +| Model | Model Identifier | TeaCache | Cache-DiT | Ulysses-SP | Ring-Attention | CFG-Parallel | Tensor-Parallel | VAE-Patch-Parallel | +|-------|------------------|:--------:|:---------:|:----------:|:--------------:|:------------:|:---------------:|:------------------:| +| **LongCat-Image** | `meituan-longcat/LongCat-Image` | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | +| **LongCat-Image-Edit** | `meituan-longcat/LongCat-Image-Edit` | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | +| **Ovis-Image** | `OvisAI/Ovis-Image` | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | +| **Qwen-Image** | `Qwen/Qwen-Image` | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | +| **Qwen-Image-2512** | `Qwen/Qwen-Image-2512` | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | +| **Qwen-Image-Edit** | `Qwen/Qwen-Image-Edit` | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | +| **Qwen-Image-Edit-2509** | `Qwen/Qwen-Image-Edit-2509` | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | +| **Qwen-Image-Layered** | `Qwen/Qwen-Image-Layered` | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | +| **Z-Image** | `Tongyi-MAI/Z-Image-Turbo` | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ (TP=2 only) | ✅ | +| **Stable-Diffusion3.5** | `stabilityai/stable-diffusion-3.5` | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | +| **Bagel** | `ByteDance-Seed/BAGEL-7B-MoT` | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | +| **FLUX.1-dev** | `black-forest-labs/FLUX.1-dev` | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ### VideoGen @@ -201,6 +205,47 @@ outputs = omni.generate( ) ``` +### Using Tensor Parallelism + +```python +from vllm_omni import Omni +from vllm_omni.diffusion.data import DiffusionParallelConfig + +omni = Omni( + model="Tongyi-MAI/Z-Image-Turbo", + parallel_config=DiffusionParallelConfig(tensor_parallel_size=2), +) + +outputs = omni.generate( + prompt="a cat reading a book", + num_inference_steps=9, + width=512, + height=512, +) +``` + +### Using VAE Patch Parallelism + +```python +from vllm_omni import Omni +from vllm_omni.diffusion.data import DiffusionParallelConfig + +omni = Omni( + model="Tongyi-MAI/Z-Image-Turbo", + parallel_config=DiffusionParallelConfig( + tensor_parallel_size=2, + vae_patch_parallel_size=2, + ), +) + +outputs = omni.generate( + prompt="a cat reading a book", + num_inference_steps=9, + width=1024, + height=1024, +) +``` + ### Using CFG-Parallel Run image-to-image: @@ -233,5 +278,7 @@ For detailed information on each acceleration method: - **[TeaCache Guide](diffusion/teacache.md)** - Complete TeaCache documentation, configuration options, and best practices - **[Cache-DiT Acceleration Guide](diffusion/cache_dit_acceleration.md)** - Comprehensive Cache-DiT guide covering DBCache, TaylorSeer, SCM, and configuration parameters +- **[Tensor Parallelism](diffusion/parallelism_acceleration.md#tensor-parallelism)** - Guidance on how to enable TP for diffusion models. - **[Sequence Parallelism](diffusion/parallelism_acceleration.md#sequence-parallelism)** - Guidance on how to set sequence parallelism with configuration. - **[CFG-Parallel](diffusion/parallelism_acceleration.md#cfg-parallel)** - Guidance on how to set CFG-Parallel to run positive/negative branches across ranks. +- **[VAE Patch Parallelism](diffusion/parallelism_acceleration.md#vae-patch-parallelism)** - Guidance on how to reduce VAE memory via patch/tile parallelism. diff --git a/examples/offline_inference/text_to_image/text_to_image.py b/examples/offline_inference/text_to_image/text_to_image.py index a79e5d640d2..4b2817f8dee 100644 --- a/examples/offline_inference/text_to_image/text_to_image.py +++ b/examples/offline_inference/text_to_image/text_to_image.py @@ -118,12 +118,6 @@ def parse_args() -> argparse.Namespace: default=1, help="Number of ready layers (blocks) to keep on GPU during generation.", ) - parser.add_argument( - "--tensor_parallel_size", - type=int, - default=1, - help="Number of GPUs used for tensor parallelism (TP) inside the DiT.", - ) parser.add_argument( "--vae_use_slicing", action="store_true", @@ -134,6 +128,18 @@ def parse_args() -> argparse.Namespace: action="store_true", help="Enable VAE tiling for memory optimization.", ) + parser.add_argument( + "--tensor_parallel_size", + type=int, + default=1, + help="Number of GPUs used for tensor parallelism (TP) inside the DiT.", + ) + parser.add_argument( + "--vae_patch_parallel_size", + type=int, + default=1, + help="Number of ranks used for VAE patch/tile parallelism (decode/encode).", + ) return parser.parse_args() @@ -176,6 +182,7 @@ def main(): ring_degree=args.ring_degree, cfg_parallel_size=args.cfg_parallel_size, tensor_parallel_size=args.tensor_parallel_size, + vae_patch_parallel_size=args.vae_patch_parallel_size, ) # Check if profiling is requested via environment variable @@ -207,8 +214,10 @@ def main(): print(f" Cache backend: {args.cache_backend if args.cache_backend else 'None (no acceleration)'}") print( f" Parallel configuration: tensor_parallel_size={args.tensor_parallel_size}, " - f"ulysses_degree={args.ulysses_degree}, ring_degree={args.ring_degree}, cfg_parallel_size={args.cfg_parallel_size}" + f"ulysses_degree={args.ulysses_degree}, ring_degree={args.ring_degree}, cfg_parallel_size={args.cfg_parallel_size}, " + f"vae_patch_parallel_size={args.vae_patch_parallel_size}" ) + print(f" CPU offload: {args.enable_cpu_offload}") print(f" Image size: {args.width}x{args.height}") print(f"{'=' * 60}\n") diff --git a/tests/diffusion/distributed/test_vae_patch_parallel.py b/tests/diffusion/distributed/test_vae_patch_parallel.py new file mode 100644 index 00000000000..43bbc71d5cf --- /dev/null +++ b/tests/diffusion/distributed/test_vae_patch_parallel.py @@ -0,0 +1,222 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Unit tests for VAE patch/tile parallelism helpers (CPU-only).""" + +import pytest +import torch + +from vllm_omni.diffusion.distributed import vae_patch_parallel as vae_patch_parallel + + +class _DummyConfig: + def __init__(self, **attrs): + for k, v in attrs.items(): + setattr(self, k, v) + + +class _DummyVae: + def __init__(self, *, config=None, **attrs): + self.config = config + for k, v in attrs.items(): + setattr(self, k, v) + + +def test_get_vae_spatial_scale_factor_uses_block_out_channels_len_minus_1(): + vae = _DummyVae(config=_DummyConfig(block_out_channels=[128, 256, 512, 512])) + assert vae_patch_parallel._get_vae_spatial_scale_factor(vae) == 8 + + vae = _DummyVae(config=_DummyConfig(block_out_channels=[1, 2, 3, 4, 5])) + assert vae_patch_parallel._get_vae_spatial_scale_factor(vae) == 16 + + +def test_get_vae_spatial_scale_factor_defaults_to_8_on_missing_or_empty(): + assert vae_patch_parallel._get_vae_spatial_scale_factor(_DummyVae(config=_DummyConfig())) == 8 + assert vae_patch_parallel._get_vae_spatial_scale_factor(_DummyVae(config=_DummyConfig(block_out_channels=[]))) == 8 + assert vae_patch_parallel._get_vae_spatial_scale_factor(_DummyVae(config=None)) == 8 + + +def test_get_vae_spatial_scale_factor_defaults_to_8_on_exception(): + class _BrokenConfig: + @property + def block_out_channels(self): + raise RuntimeError("boom") + + assert vae_patch_parallel._get_vae_spatial_scale_factor(_DummyVae(config=_BrokenConfig())) == 8 + + +@pytest.mark.parametrize( + ("pp_size", "expected"), + [ + (0, (1, 1)), + (1, (1, 1)), + (2, (1, 2)), + (3, (1, 3)), + (4, (2, 2)), + (6, (2, 3)), + (8, (2, 4)), + (12, (3, 4)), + (16, (4, 4)), + ], +) +def test_factor_pp_grid(pp_size: int, expected: tuple[int, int]): + assert vae_patch_parallel._factor_pp_grid(pp_size) == expected + + +def test_get_world_rank_pp_size(monkeypatch): + monkeypatch.setattr(vae_patch_parallel.dist, "get_world_size", lambda _: 8) + monkeypatch.setattr(vae_patch_parallel.dist, "get_rank", lambda _: 3) + + world_size, rank, pp_size = vae_patch_parallel._get_world_rank_pp_size(object(), 4) + assert (world_size, rank, pp_size) == (8, 3, 4) + + world_size, rank, pp_size = vae_patch_parallel._get_world_rank_pp_size(object(), 16) + assert (world_size, rank, pp_size) == (8, 3, 8) + + +def test_get_vae_out_channels_defaults_to_3(): + assert vae_patch_parallel._get_vae_out_channels(_DummyVae(config=None)) == 3 + assert vae_patch_parallel._get_vae_out_channels(_DummyVae(config=_DummyConfig())) == 3 + + +def test_get_vae_out_channels_reads_config(): + assert vae_patch_parallel._get_vae_out_channels(_DummyVae(config=_DummyConfig(out_channels=4))) == 4 + assert vae_patch_parallel._get_vae_out_channels(_DummyVae(config=_DummyConfig(out_channels="5"))) == 5 + + +def test_get_vae_tile_params_returns_none_if_missing(): + assert ( + vae_patch_parallel._get_vae_tile_params(_DummyVae(tile_latent_min_size=None, tile_overlap_factor=0.25)) is None + ) + assert ( + vae_patch_parallel._get_vae_tile_params(_DummyVae(tile_latent_min_size=128, tile_overlap_factor=None)) is None + ) + + +def test_get_vae_tile_params_parses_types(): + vae = _DummyVae(tile_latent_min_size="128", tile_overlap_factor="0.25") + assert vae_patch_parallel._get_vae_tile_params(vae) == (128, 0.25) + + +def test_get_vae_tiling_params_returns_none_if_missing(): + vae = _DummyVae(tile_latent_min_size=128, tile_overlap_factor=0.25, tile_sample_min_size=None) + assert vae_patch_parallel._get_vae_tiling_params(vae) is None + + vae = _DummyVae(tile_latent_min_size=None, tile_overlap_factor=0.25, tile_sample_min_size=1024) + assert vae_patch_parallel._get_vae_tiling_params(vae) is None + + +def test_get_vae_tiling_params_parses_types(): + vae = _DummyVae(tile_latent_min_size="128", tile_overlap_factor="0.25", tile_sample_min_size="1024") + assert vae_patch_parallel._get_vae_tiling_params(vae) == (128, 0.25, 1024) + + +def test_distributed_tiled_decode_stitches_tiles(monkeypatch): + class _TinyConfig: + def __init__(self): + self.out_channels = 1 + self.use_post_quant_conv = False + + class _TinyVae: + def __init__(self): + self.config = _TinyConfig() + self.tile_latent_min_size = 2 + self.tile_overlap_factor = 0.0 + self.tile_sample_min_size = 2 + + def decoder(self, x: torch.Tensor) -> torch.Tensor: + return x + + def blend_v(self, _a: torch.Tensor, b: torch.Tensor, _blend_extent: int) -> torch.Tensor: + return b + + def blend_h(self, _a: torch.Tensor, b: torch.Tensor, _blend_extent: int) -> torch.Tensor: + return b + + def _collect_local_tiles( + *, + vae: _TinyVae, + z: torch.Tensor, + rank: int, + pp_size: int, + ) -> tuple[list[torch.Tensor], list[tuple[int, int, int]]]: + tile_latent_min_size = vae.tile_latent_min_size + overlap_size = int(tile_latent_min_size * (1 - vae.tile_overlap_factor)) + h_starts = list(range(0, z.shape[2], overlap_size)) + w_starts = list(range(0, z.shape[3], overlap_size)) + + local_tiles: list[torch.Tensor] = [] + local_meta: list[tuple[int, int, int]] = [] + tile_id = 0 + for i in h_starts: + for j in w_starts: + tile_rank = (tile_id + 1) % pp_size + if tile_rank == rank: + tile = z[:, :, i : i + tile_latent_min_size, j : j + tile_latent_min_size] + decoded = vae.decoder(tile) + local_tiles.append(decoded) + local_meta.append((tile_id, int(decoded.shape[-2]), int(decoded.shape[-1]))) + tile_id += 1 + return local_tiles, local_meta + + vae = _TinyVae() + z = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4) + + rank0_tiles, rank0_meta = _collect_local_tiles(vae=vae, z=z, rank=0, pp_size=2) + rank1_tiles, rank1_meta = _collect_local_tiles(vae=vae, z=z, rank=1, pp_size=2) + max_count = max(len(rank0_tiles), len(rank1_tiles)) + + def _pack_meta_and_tiles( + tiles: list[torch.Tensor], + meta: list[tuple[int, int, int]], + max_count: int, + ) -> tuple[torch.Tensor, torch.Tensor]: + meta_tensor = torch.full((max_count, 3), -1, dtype=torch.int64) + tile_tensor = torch.zeros( + (max_count, z.shape[0], vae.config.out_channels, vae.tile_sample_min_size, vae.tile_sample_min_size), + dtype=z.dtype, + ) + for idx, (tile_id, h, w) in enumerate(meta): + meta_tensor[idx, 0] = tile_id + meta_tensor[idx, 1] = h + meta_tensor[idx, 2] = w + tile_tensor[idx, :, :, :h, :w] = tiles[idx] + return meta_tensor, tile_tensor + + rank1_meta_tensor, rank1_tile_tensor = _pack_meta_and_tiles(rank1_tiles, rank1_meta, max_count) + rank1_count_tensor = torch.tensor([len(rank1_tiles)], dtype=torch.int64) + + def _fake_gather(tensor, gather_list=None, dst=0, group=None): + if gather_list is None: + return + if tensor.ndim == 1 and tensor.numel() == 1: + gather_list[0].copy_(tensor) + gather_list[1].copy_(rank1_count_tensor) + return + if tensor.ndim == 2 and tensor.shape[1] == 3: + gather_list[0].copy_(tensor) + gather_list[1].copy_(rank1_meta_tensor) + return + if tensor.ndim == 5: + gather_list[0].copy_(tensor) + gather_list[1].copy_(rank1_tile_tensor) + return + raise AssertionError("Unexpected gather payload for test.") + + def _fake_broadcast(_tensor, src=0, group=None): + return + + monkeypatch.setattr(vae_patch_parallel.dist, "get_world_size", lambda _group: 2) + monkeypatch.setattr(vae_patch_parallel.dist, "get_rank", lambda _group: 0) + monkeypatch.setattr(vae_patch_parallel.dist, "gather", _fake_gather) + monkeypatch.setattr(vae_patch_parallel.dist, "broadcast", _fake_broadcast) + + output = vae_patch_parallel._distributed_tiled_decode( + vae=vae, + orig_decode=lambda z, return_dict=False: (z,), + z=z, + group=object(), + vae_patch_parallel_size=2, + ) + + assert torch.equal(output, z) diff --git a/tests/e2e/offline_inference/test_zimage_tensor_parallel.py b/tests/e2e/offline_inference/test_zimage_parallelism.py similarity index 68% rename from tests/e2e/offline_inference/test_zimage_tensor_parallel.py rename to tests/e2e/offline_inference/test_zimage_parallelism.py index 2d051e5aaf5..79fdb4699ed 100644 --- a/tests/e2e/offline_inference/test_zimage_tensor_parallel.py +++ b/tests/e2e/offline_inference/test_zimage_parallelism.py @@ -1,6 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Z-Image end-to-end tests for diffusion parallelism. + +This file currently covers: +- DiT tensor parallelism (TP=2) vs TP=1. +- VAE patch parallelism (vae_patch_parallel_size=2) vs baseline on TP=2. + +Note: CUDA-only (>=2 GPUs). We use `enforce_eager=False` (default) to enable +`torch.compile`. +""" + import os import sys import time @@ -69,16 +79,32 @@ def _extract_single_image(outputs) -> Image.Image: def _run_zimage_generate( - *, tp_size: int, height: int, width: int, num_inference_steps: int, seed: int + *, + tp_size: int, + height: int, + width: int, + num_inference_steps: int, + seed: int, + enforce_eager: bool, + vae_use_tiling: bool = False, + vae_patch_parallel_size: int = 1, + num_requests: int = 4, ) -> tuple[Image.Image, float, float]: + if num_requests < 2: + raise ValueError("num_requests must be >= 2 (1 warmup + >=1 timed)") + torch.cuda.empty_cache() device_index = torch.cuda.current_device() monitor = GPUMemoryMonitor(device_index=device_index, interval=0.02) monitor.start() - m = Omni( model=_get_zimage_model(), - parallel_config=DiffusionParallelConfig(tensor_parallel_size=tp_size), + parallel_config=DiffusionParallelConfig( + tensor_parallel_size=tp_size, + vae_patch_parallel_size=vae_patch_parallel_size, + ), + enforce_eager=enforce_eager, + vae_use_tiling=vae_use_tiling, ) try: # NOTE: Omni closes itself when a generate() call is exhausted. @@ -89,7 +115,6 @@ def _run_zimage_generate( # This also serves as a warmup: the first output may include extra # compilation/caching overhead, while later outputs are closer to # steady-state inference. - num_requests = 4 # 1 warmup + 3 timed gen = m.generate( [PROMPT] * num_requests, OmniDiffusionSamplingParams( @@ -104,6 +129,7 @@ def _run_zimage_generate( ) warmup_output = next(gen) + t_prev = time.perf_counter() per_request_times_s: list[float] = [] last_output = warmup_output @@ -124,6 +150,7 @@ def _run_zimage_generate( return _extract_single_image([last_output]), median_time_s, peak_memory_mb finally: monitor.stop() + m.close() cleanup_dist_env_and_memory() @@ -137,6 +164,8 @@ def test_zimage_tensor_parallel_tp2(tmp_path: Path): if not torch.cuda.is_available() or torch.cuda.device_count() < 2: pytest.skip("Z-Image TP=2 requires >= 2 CUDA devices.") + enforce_eager = False + height = 512 width = 512 num_inference_steps = 2 @@ -148,6 +177,7 @@ def test_zimage_tensor_parallel_tp2(tmp_path: Path): width=width, num_inference_steps=num_inference_steps, seed=seed, + enforce_eager=enforce_eager, ) tp2_img, tp2_time_s, tp2_peak_mem = _run_zimage_generate( tp_size=2, @@ -155,6 +185,7 @@ def test_zimage_tensor_parallel_tp2(tmp_path: Path): width=width, num_inference_steps=num_inference_steps, seed=seed, + enforce_eager=enforce_eager, ) tp1_path = tmp_path / "zimage_tp1.png" @@ -186,3 +217,61 @@ def test_zimage_tensor_parallel_tp2(tmp_path: Path): assert tp2_peak_mem < tp1_peak_mem, ( f"Expected TP=2 to use less peak memory than TP=1 (tp1={tp1_peak_mem}, tp2={tp2_peak_mem})" ) + + +@pytest.mark.integration +def test_zimage_vae_patch_parallel_tp2(tmp_path: Path): + if current_omni_platform.is_npu() or current_omni_platform.is_rocm(): + pytest.skip("Z-Image VAE patch parallel e2e test is only supported on CUDA for now.") + if not torch.cuda.is_available() or torch.cuda.device_count() < 2: + pytest.skip("Z-Image VAE patch parallel TP=2 requires >= 2 CUDA devices.") + + enforce_eager = False + + # Use a larger image to ensure there are multiple VAE tiles. + height = 1152 + width = 1152 + num_inference_steps = 2 + seed = 42 + + baseline_img, _baseline_time_s, _baseline_peak_mem = _run_zimage_generate( + tp_size=2, + height=height, + width=width, + num_inference_steps=num_inference_steps, + seed=seed, + enforce_eager=enforce_eager, + vae_use_tiling=True, + vae_patch_parallel_size=1, + num_requests=2, + ) + pp2_img, _pp2_time_s, _pp2_peak_mem = _run_zimage_generate( + tp_size=2, + height=height, + width=width, + num_inference_steps=num_inference_steps, + seed=seed, + enforce_eager=enforce_eager, + vae_use_tiling=True, + vae_patch_parallel_size=2, + num_requests=2, + ) + + baseline_path = tmp_path / "zimage_tp2_vae_pp1.png" + pp2_path = tmp_path / "zimage_tp2_vae_pp2.png" + baseline_img.save(baseline_path) + pp2_img.save(pp2_path) + + mean_abs_diff, max_abs_diff = _diff_metrics(baseline_img, pp2_img) + mean_threshold = 5e-3 + max_threshold = 1e-1 + print( + "Z-Image VAE patch parallel image diff stats (TP=2, pp=1 vs pp=2): " + f"mean_abs_diff={mean_abs_diff:.6e}, max_abs_diff={max_abs_diff:.6e}; " + f"thresholds: mean<={mean_threshold:.6e}, max<={max_threshold:.6e}; " + f"pp1_img={baseline_path}, pp2_img={pp2_path}" + ) + assert mean_abs_diff <= mean_threshold and max_abs_diff <= max_threshold, ( + f"Image diff exceeded threshold: mean_abs_diff={mean_abs_diff:.6e}, max_abs_diff={max_abs_diff:.6e} " + f"(thresholds: mean<={mean_threshold:.6e}, max<={max_threshold:.6e})" + ) diff --git a/vllm_omni/diffusion/data.py b/vllm_omni/diffusion/data.py index f884fb6f177..a4f0ba6fa56 100644 --- a/vllm_omni/diffusion/data.py +++ b/vllm_omni/diffusion/data.py @@ -45,6 +45,9 @@ class DiffusionParallelConfig: cfg_parallel_size: int = 1 """Number of Classifier Free Guidance (CFG) parallel groups.""" + vae_patch_parallel_size: int = 1 + """Number of ranks used for VAE patch/tile parallelism (decode/encode).""" + @model_validator(mode="after") def _validate_parallel_config(self) -> Self: """Validates the config relationships among the parallel strategies.""" @@ -56,6 +59,7 @@ def _validate_parallel_config(self) -> Self: assert self.ring_degree > 0, "Ring degree must be > 0" assert self.cfg_parallel_size > 0, "CFG parallel size must be > 0" assert self.cfg_parallel_size in [1, 2], f"CFG parallel size must be 1 or 2, but got {self.cfg_parallel_size}" + assert self.vae_patch_parallel_size > 0, "VAE patch parallel size must be > 0" assert self.sequence_parallel_size == self.ulysses_degree * self.ring_degree, ( "Sequence parallel size must be equal to the product of ulysses degree and ring degree," f" but got {self.sequence_parallel_size} != {self.ulysses_degree} * {self.ring_degree}" @@ -65,6 +69,7 @@ def _validate_parallel_config(self) -> Self: def __post_init__(self) -> None: if self.sequence_parallel_size is None: self.sequence_parallel_size = self.ulysses_degree * self.ring_degree + self.world_size = ( self.pipeline_parallel_size * self.data_parallel_size @@ -359,7 +364,6 @@ class OmniDiffusionConfig: # support multi images input supports_multimodal_inputs: bool = False - # Logging log_level: str = "info" # Omni configuration (injected from stage config) diff --git a/vllm_omni/diffusion/distributed/parallel_state.py b/vllm_omni/diffusion/distributed/parallel_state.py index 12da9c43b02..26112af4607 100644 --- a/vllm_omni/diffusion/distributed/parallel_state.py +++ b/vllm_omni/diffusion/distributed/parallel_state.py @@ -57,7 +57,6 @@ _CFG: GroupCoordinator | None = None _DP: GroupCoordinator | None = None _DIT: GroupCoordinator | None = None -_VAE: GroupCoordinator | None = None def generate_masked_orthogonal_rank_groups( @@ -344,7 +343,7 @@ def is_dp_last_group(): def get_dit_world_size(): - """Return world size for the DiT model (excluding VAE).""" + """Return world size for the DiT model.""" return ( get_data_parallel_world_size() * get_classifier_free_guidance_world_size() @@ -354,22 +353,6 @@ def get_dit_world_size(): ) -# Add VAE getter functions -def get_vae_parallel_group() -> GroupCoordinator: - assert _VAE is not None, "VAE parallel group is not initialized" - return _VAE - - -def get_vae_parallel_world_size(): - """Return world size for the VAE parallel group.""" - return get_vae_parallel_group().world_size - - -def get_vae_parallel_rank(): - """Return my rank for the VAE parallel group.""" - return get_vae_parallel_group().rank_in_group - - # * SET @@ -490,18 +473,6 @@ def get_dit_group(): return _DIT -def init_vae_group( - dit_parallel_size: int, - vae_parallel_size: int, - backend: str, -): - # Initialize VAE group first - global _VAE - assert _VAE is None, "VAE parallel group is already initialized" - vae_ranks = list(range(dit_parallel_size, dit_parallel_size + vae_parallel_size)) - _VAE = torch.distributed.new_group(ranks=vae_ranks, backend=backend) - - # adapted from https://github.com/feifeibear/long-context-attention/blob/main/yunchang/globals.py def set_seq_parallel_pg( sp_ulysses_degree: int, @@ -658,7 +629,6 @@ def initialize_model_parallel( ring_degree: int = 1, tensor_parallel_size: int = 1, pipeline_parallel_size: int = 1, - vae_parallel_size: int = 0, backend: str | None = None, ) -> None: if backend is None: @@ -802,8 +772,6 @@ def initialize_model_parallel( backend=backend, parallel_mode="tensor", ) - if vae_parallel_size > 0: - init_vae_group(dit_parallel_size, vae_parallel_size, backend) init_dit_group(dit_parallel_size, backend) @@ -833,11 +801,6 @@ def destroy_model_parallel(): _PP.destroy() _PP = None - global _VAE - if _VAE: - _VAE.destroy() - _VAE = None - def destroy_distributed_environment(): global _WORLD diff --git a/vllm_omni/diffusion/distributed/vae_patch_parallel.py b/vllm_omni/diffusion/distributed/vae_patch_parallel.py new file mode 100644 index 00000000000..1a3edde3ceb --- /dev/null +++ b/vllm_omni/diffusion/distributed/vae_patch_parallel.py @@ -0,0 +1,477 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Distributed VAE patch/tile parallelism utilities.""" + +from __future__ import annotations + +import math +from collections.abc import Callable +from typing import Any + +import torch +import torch.distributed as dist +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def _get_vae_spatial_scale_factor(vae: Any) -> int: + try: + block_out_channels = getattr(getattr(vae, "config", None), "block_out_channels", None) + if block_out_channels: + return 2 ** (len(block_out_channels) - 1) + except Exception: + pass + return 8 + + +def _factor_pp_grid(pp_size: int) -> tuple[int, int]: + """Pick a (rows, cols) grid whose product equals `pp_size`.""" + if pp_size <= 1: + return (1, 1) + root = int(math.sqrt(pp_size)) + for rows in range(root, 0, -1): + if pp_size % rows == 0: + return (rows, pp_size // rows) + return (1, pp_size) + + +def _get_world_rank_pp_size( + group: dist.ProcessGroup, + vae_patch_parallel_size: int, +) -> tuple[int, int, int]: + world_size = dist.get_world_size(group) + rank = dist.get_rank(group) + pp_size = min(int(vae_patch_parallel_size), int(world_size)) + return world_size, rank, pp_size + + +def _get_vae_out_channels(vae: Any) -> int: + return int(getattr(getattr(vae, "config", None), "out_channels", 3)) + + +def _get_vae_tile_params(vae: Any) -> tuple[int, float] | None: + tile_latent_min_size = getattr(vae, "tile_latent_min_size", None) + tile_overlap_factor = getattr(vae, "tile_overlap_factor", None) + if tile_latent_min_size is None or tile_overlap_factor is None: + return None + return int(tile_latent_min_size), float(tile_overlap_factor) + + +def _get_vae_tiling_params(vae: Any) -> tuple[int, float, int] | None: + tile_sample_min_size = getattr(vae, "tile_sample_min_size", None) + tile_params = _get_vae_tile_params(vae) + if tile_params is None or tile_sample_min_size is None: + return None + tile_latent_min_size, tile_overlap_factor = tile_params + return tile_latent_min_size, tile_overlap_factor, int(tile_sample_min_size) + + +def _distributed_tiled_decode( + *, + vae: Any, + orig_decode: Callable[..., Any], + z: torch.Tensor, + group: dist.ProcessGroup, + vae_patch_parallel_size: int, +) -> torch.Tensor: + """Distributed version of diffusers AutoencoderKL.tiled_decode (decode only). + + Each rank decodes a subset of tiles; rank0 gathers all tiles and performs the + original blend + stitch logic. Non-rank0 ranks return an empty tensor; callers + can broadcast the stitched result if needed. + """ + world_size, rank, pp_size = _get_world_rank_pp_size(group, vae_patch_parallel_size) + if pp_size <= 1: + return orig_decode(z, return_dict=False)[0] + + tiling_params = _get_vae_tiling_params(vae) + if tiling_params is None: + return orig_decode(z, return_dict=False)[0] + tile_latent_min_size, tile_overlap_factor, tile_sample_min_size = tiling_params + + overlap_size = int(tile_latent_min_size * (1 - tile_overlap_factor)) + if overlap_size <= 0: + return orig_decode(z, return_dict=False)[0] + + h_starts = list(range(0, z.shape[2], overlap_size)) + w_starts = list(range(0, z.shape[3], overlap_size)) + num_rows = len(h_starts) + num_cols = len(w_starts) + num_tiles = num_rows * num_cols + + if num_tiles < 2: + return orig_decode(z, return_dict=False)[0] + + blend_extent = int(tile_sample_min_size * tile_overlap_factor) + row_limit = int(tile_sample_min_size - blend_extent) + + # Decide which ranks actively decode tiles. + active = rank < pp_size + + local_tiles: list[torch.Tensor] = [] + local_meta: list[tuple[int, int, int]] = [] + + tile_id = 0 + for i in h_starts: + for j in w_starts: + # Offset assignment by 1 so rank0 avoids decoding the largest (tile_id=0) tile. + tile_rank = (tile_id + 1) % pp_size + if active and (tile_rank == rank): + tile = z[:, :, i : i + tile_latent_min_size, j : j + tile_latent_min_size] + if getattr(getattr(vae, "config", None), "use_post_quant_conv", False): + tile = vae.post_quant_conv(tile) + decoded = vae.decoder(tile) + local_tiles.append(decoded) + local_meta.append((tile_id, int(decoded.shape[-2]), int(decoded.shape[-1]))) + tile_id += 1 + + # Gather per-rank tile counts. + count_tensor = torch.tensor([len(local_tiles)], device=z.device, dtype=torch.int64) + if rank == 0: + count_gather = [torch.empty_like(count_tensor) for _ in range(world_size)] + else: + count_gather = None + dist.gather(count_tensor, gather_list=count_gather, dst=0, group=group) + max_count = 0 + if rank == 0: + counts = [int(t.item()) for t in count_gather] # type: ignore[arg-type] + max_count = max(counts) if counts else 0 + max_count_tensor = torch.tensor([max_count], device=z.device, dtype=torch.int64) + dist.broadcast(max_count_tensor, src=0, group=group) + max_count = int(max_count_tensor.item()) + + out_channels = _get_vae_out_channels(vae) + + # Prepare padded metadata + tiles for gather. + meta_tensor = torch.full((max_count, 3), -1, device=z.device, dtype=torch.int64) + tile_tensor = torch.zeros( + (max_count, z.shape[0], out_channels, tile_sample_min_size, tile_sample_min_size), + device=z.device, + dtype=z.dtype, + ) + for idx, (tile_id, h, w) in enumerate(local_meta): + meta_tensor[idx, 0] = tile_id + meta_tensor[idx, 1] = h + meta_tensor[idx, 2] = w + tile_tensor[idx, :, :, :h, :w] = local_tiles[idx] + + if rank == 0: + meta_gather = [torch.empty_like(meta_tensor) for _ in range(world_size)] + tile_gather = [torch.empty_like(tile_tensor) for _ in range(world_size)] + else: + meta_gather = None + tile_gather = None + + dist.gather(meta_tensor, gather_list=meta_gather, dst=0, group=group) + dist.gather(tile_tensor, gather_list=tile_gather, dst=0, group=group) + + if rank != 0: + return torch.empty(0, device=z.device, dtype=z.dtype) + + # Reconstruct the full tile grid on rank0. + tile_map: dict[int, torch.Tensor] = {} + for src_rank in range(world_size): + meta_src = meta_gather[src_rank] # type: ignore[index] + tiles_src = tile_gather[src_rank] # type: ignore[index] + for idx in range(max_count): + tid = int(meta_src[idx, 0].item()) + if tid < 0: + continue + h = int(meta_src[idx, 1].item()) + w = int(meta_src[idx, 2].item()) + tile_map[tid] = tiles_src[idx, :, :, :h, :w] + + rows: list[list[torch.Tensor]] = [] + for r in range(num_rows): + row: list[torch.Tensor] = [] + for c in range(num_cols): + tid = r * num_cols + c + row.append(tile_map[tid]) + rows.append(row) + + result_rows: list[torch.Tensor] = [] + for i, row in enumerate(rows): + result_row: list[torch.Tensor] = [] + for j, tile in enumerate(row): + if i > 0: + tile = vae.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = vae.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=3)) + + return torch.cat(result_rows, dim=2) + + +def _distributed_patch_decode( + *, + vae: Any, + orig_decode: Callable[..., Any], + z: torch.Tensor, + group: dist.ProcessGroup, + vae_patch_parallel_size: int, + vae_scale_factor: int, +) -> torch.Tensor: + """Decode one spatial block per rank, then stitch RGB blocks on rank0. + + Intended for sizes where diffusers tiling would not kick in, so we can still + reduce the per-rank VAE decode activation peak. Each rank decodes a core + block with a small latent-space halo, then crops to the core and gathers the + RGB blocks to rank0 for final stitching. + """ + world_size, rank, pp_size = _get_world_rank_pp_size(group, vae_patch_parallel_size) + if pp_size <= 1: + return orig_decode(z, return_dict=False)[0] + + tile_params = _get_vae_tile_params(vae) + if tile_params is None: + return orig_decode(z, return_dict=False)[0] + tile_latent_min_size, tile_overlap_factor = tile_params + + overlap_latent = int(tile_latent_min_size * float(tile_overlap_factor)) + halo_base = max(0, overlap_latent // 2) + + # Only ranks < pp_size participate in decoding. Others send empty tensors. + active = rank < pp_size + + bsz, _, latent_h, latent_w = z.shape + scale = int(vae_scale_factor) + out_h = latent_h * scale + out_w = latent_w * scale + + out_channels = _get_vae_out_channels(vae) + + local_core = torch.empty(0, device=z.device, dtype=z.dtype) + local_h = 0 + local_w = 0 + + grid_rows, grid_cols = _factor_pp_grid(pp_size) + + if active: + patch_idx = rank + patch_row = patch_idx // grid_cols + patch_col = patch_idx % grid_cols + + h0 = (patch_row * latent_h) // grid_rows + h1 = ((patch_row + 1) * latent_h) // grid_rows + w0 = (patch_col * latent_w) // grid_cols + w1 = ((patch_col + 1) * latent_w) // grid_cols + + core_h = max(0, h1 - h0) + core_w = max(0, w1 - w0) + if core_h == 0 or core_w == 0: + local_core = torch.empty(0, device=z.device, dtype=z.dtype) + else: + halo = max(halo_base, min(core_h, core_w) // 2) + ph0 = max(0, h0 - halo) + ph1 = min(latent_h, h1 + halo) + pw0 = max(0, w0 - halo) + pw1 = min(latent_w, w1 + halo) + + tile = z[:, :, ph0:ph1, pw0:pw1] + if getattr(getattr(vae, "config", None), "use_post_quant_conv", False): + tile = vae.post_quant_conv(tile) + decoded = vae.decoder(tile) + + ch0 = (h0 - ph0) * scale + cw0 = (w0 - pw0) * scale + ch1 = ch0 + core_h * scale + cw1 = cw0 + core_w * scale + local_core = decoded[:, :, ch0:ch1, cw0:cw1] + + local_h = int(local_core.shape[-2]) if local_core.numel() else 0 + local_w = int(local_core.shape[-1]) if local_core.numel() else 0 + + # Gather block shapes. + shape_tensor = torch.tensor([local_h, local_w], device=z.device, dtype=torch.int64) + if rank == 0: + shape_gather = [torch.empty_like(shape_tensor) for _ in range(world_size)] + else: + shape_gather = None + dist.gather(shape_tensor, gather_list=shape_gather, dst=0, group=group) + + max_h = 0 + max_w = 0 + if rank == 0: + shapes = [tuple(int(x.item()) for x in t) for t in shape_gather] # type: ignore[arg-type] + max_h = max((h for h, _ in shapes), default=0) + max_w = max((w for _, w in shapes), default=0) + + max_hw_tensor = torch.tensor([max_h, max_w], device=z.device, dtype=torch.int64) + dist.broadcast(max_hw_tensor, src=0, group=group) + max_h = int(max_hw_tensor[0].item()) + max_w = int(max_hw_tensor[1].item()) + + # Pad local block for gather. + if max_h == 0 or max_w == 0: + padded = torch.empty(0, device=z.device, dtype=z.dtype) + else: + padded = torch.zeros((bsz, out_channels, max_h, max_w), device=z.device, dtype=z.dtype) + if local_h and local_w: + padded[:, :, :local_h, :local_w] = local_core + + if rank == 0: + block_gather = [torch.empty_like(padded) for _ in range(world_size)] + else: + block_gather = None + dist.gather(padded, gather_list=block_gather, dst=0, group=group) + + if rank != 0: + return torch.empty(0, device=z.device, dtype=z.dtype) + + # Stitch on rank0. + out = torch.empty((bsz, out_channels, out_h, out_w), device=z.device, dtype=z.dtype) + + for patch_idx in range(pp_size): + src_rank = patch_idx + patch_row = patch_idx // grid_cols + patch_col = patch_idx % grid_cols + + h0 = (patch_row * latent_h) // grid_rows + h1 = ((patch_row + 1) * latent_h) // grid_rows + w0 = (patch_col * latent_w) // grid_cols + w1 = ((patch_col + 1) * latent_w) // grid_cols + + ph = (h1 - h0) * scale + pw = (w1 - w0) * scale + if ph <= 0 or pw <= 0: + continue + + tile = block_gather[src_rank] # type: ignore[index] + out[:, :, h0 * scale : h1 * scale, w0 * scale : w1 * scale] = tile[:, :, :ph, :pw] + + return out + + +class VaePatchParallelism: + """Patch/tile-parallel VAE decode wrapper. + + This is meant to wrap `vae.decode` as an instance-level override + so pipelines don't need model-specific code paths. + """ + + def __init__( + self, + vae: Any, + *, + vae_patch_parallel_size: int, + group_getter: Callable[[], dist.ProcessGroup], + ) -> None: + self._vae = vae + self._vae_patch_parallel_size = int(vae_patch_parallel_size) + self._group_getter = group_getter + + self._vae_scale_factor = _get_vae_spatial_scale_factor(vae) + self._orig_decode = vae.decode + + def decode(self, z: torch.Tensor, return_dict: bool = True, *args: Any, **kwargs: Any): + # Keep the original path for unsupported VAE types / shapes. + if z.ndim != 4: + return self._orig_decode(z, return_dict=return_dict, *args, **kwargs) + + if self._vae_patch_parallel_size <= 1 or not dist.is_initialized(): + return self._orig_decode(z, return_dict=return_dict, *args, **kwargs) + + if not getattr(self._vae, "use_tiling", False): + return self._orig_decode(z, return_dict=return_dict, *args, **kwargs) + + try: + group = self._group_getter() + except Exception: + return self._orig_decode(z, return_dict=return_dict, *args, **kwargs) + + world_size = dist.get_world_size(group) + rank = dist.get_rank(group) + pp_size = min(int(self._vae_patch_parallel_size), int(world_size)) + if pp_size <= 1: + return self._orig_decode(z, return_dict=return_dict, *args, **kwargs) + + # Match diffusers' condition for when VAE tiling would be used. + tile_latent_min_size = getattr(self._vae, "tile_latent_min_size", None) + if tile_latent_min_size is None: + decoded = _distributed_tiled_decode( + vae=self._vae, + orig_decode=self._orig_decode, + z=z, + group=group, + vae_patch_parallel_size=pp_size, + ) + else: + should_tile = (z.shape[-1] > tile_latent_min_size) or (z.shape[-2] > tile_latent_min_size) + if should_tile: + decoded = _distributed_tiled_decode( + vae=self._vae, + orig_decode=self._orig_decode, + z=z, + group=group, + vae_patch_parallel_size=pp_size, + ) + else: + decoded = _distributed_patch_decode( + vae=self._vae, + orig_decode=self._orig_decode, + z=z, + group=group, + vae_patch_parallel_size=pp_size, + vae_scale_factor=self._vae_scale_factor, + ) + + if rank == 0 and decoded.numel() == 0: + logger.warning("VAE patch parallel decode produced empty output on rank0; falling back to vae.decode.") + decoded = self._orig_decode(z, return_dict=False, *args, **kwargs)[0] + if rank == 0 and decoded.dtype != z.dtype: + decoded = decoded.to(dtype=z.dtype) + if rank == 0 and not decoded.is_contiguous(): + decoded = decoded.contiguous() + + shape_tensor = torch.empty((4,), device=z.device, dtype=torch.int64) + if rank == 0: + shape_tensor.copy_(torch.tensor(tuple(decoded.shape), device=z.device, dtype=torch.int64)) + dist.broadcast(shape_tensor, src=0, group=group) + + if rank != 0: + decoded = torch.empty(tuple(int(x) for x in shape_tensor.tolist()), device=z.device, dtype=z.dtype) + dist.broadcast(decoded, src=0, group=group) + + if not return_dict: + return (decoded,) + + from diffusers.models.autoencoders.vae import DecoderOutput + + return DecoderOutput(sample=decoded) + + +def maybe_wrap_vae_decode_with_patch_parallelism( + pipeline: Any, + *, + vae_patch_parallel_size: int, + group_getter: Callable[[], dist.ProcessGroup], +) -> None: + """Wrap a diffusers-style pipeline's `vae.decode` with patch/tile parallel decode.""" + if vae_patch_parallel_size <= 1: + return + + vae = getattr(pipeline, "vae", None) + if vae is None or not hasattr(vae, "decode"): + return + try: + from diffusers.models.autoencoders import AutoencoderKL + except Exception: + return + if not isinstance(vae, AutoencoderKL): + return + + if getattr(vae, "_vllm_vae_patch_parallel_installed", False): + return + + wrapper = VaePatchParallelism( + vae, + vae_patch_parallel_size=vae_patch_parallel_size, + group_getter=group_getter, + ) + + vae._vllm_vae_patch_parallel_installed = True # type: ignore[attr-defined] + vae._vllm_vae_patch_parallel_original_decode = vae.decode # type: ignore[attr-defined] + vae.decode = wrapper.decode # type: ignore[assignment] diff --git a/vllm_omni/diffusion/registry.py b/vllm_omni/diffusion/registry.py index 3d236ac31b5..173c576a243 100644 --- a/vllm_omni/diffusion/registry.py +++ b/vllm_omni/diffusion/registry.py @@ -113,6 +113,11 @@ } ) +_VAE_PATCH_PARALLEL_ALLOWLIST = { + # Only enable for models we have validated end-to-end. + "ZImagePipeline", +} + def initialize_model( od_config: OmniDiffusionConfig, @@ -137,12 +142,46 @@ def initialize_model( model_class = DiffusionModelRegistry._try_load_model_cls(od_config.model_class_name) if model_class is not None: model = model_class(od_config=od_config) + + vae_pp_size = od_config.parallel_config.vae_patch_parallel_size + if vae_pp_size > 1 and od_config.model_class_name not in _VAE_PATCH_PARALLEL_ALLOWLIST: + logger.warning( + "vae_patch_parallel_size=%d is set but VAE patch parallelism is only enabled for %s; ignoring.", + vae_pp_size, + sorted(_VAE_PATCH_PARALLEL_ALLOWLIST), + ) + if ( + vae_pp_size > 1 + and od_config.model_class_name in _VAE_PATCH_PARALLEL_ALLOWLIST + and not od_config.vae_use_tiling + ): + logger.info( + "vae_patch_parallel_size=%d requires vae_use_tiling; automatically enabling it.", + vae_pp_size, + ) + od_config.vae_use_tiling = True + # Configure VAE memory optimization settings from config if hasattr(model.vae, "use_slicing"): model.vae.use_slicing = od_config.vae_use_slicing if hasattr(model.vae, "use_tiling"): model.vae.use_tiling = od_config.vae_use_tiling + if ( + vae_pp_size > 1 + and hasattr(model, "vae") + and od_config.model_class_name in _VAE_PATCH_PARALLEL_ALLOWLIST + and od_config.vae_use_tiling + ): + from vllm_omni.diffusion.distributed.parallel_state import get_dit_group + from vllm_omni.diffusion.distributed.vae_patch_parallel import maybe_wrap_vae_decode_with_patch_parallelism + + maybe_wrap_vae_decode_with_patch_parallelism( + model, + vae_patch_parallel_size=vae_pp_size, + group_getter=get_dit_group, + ) + # Apply sequence parallelism if enabled # This follows diffusers' pattern where enable_parallelism() is called # at model loading time, not inside individual model files diff --git a/vllm_omni/engine/input_processor.py b/vllm_omni/engine/input_processor.py index a1e467a88ac..64ffab3a1ab 100644 --- a/vllm_omni/engine/input_processor.py +++ b/vllm_omni/engine/input_processor.py @@ -11,7 +11,17 @@ from vllm.logger import init_logger from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalUUIDDict -from vllm.multimodal.processing.context import set_request_id + +try: + from vllm.multimodal.processing import set_request_id +except ImportError: # vllm without set_request_id (older releases) + from contextlib import contextmanager + + @contextmanager + def set_request_id(_request_id: str): + yield + + from vllm.multimodal.utils import argsort_mm_positions from vllm.platforms import current_platform from vllm.pooling_params import PoolingParams