From c1779e1f0916871276d6f427326b8984f74eaa6c Mon Sep 17 00:00:00 2001 From: "siyuan.lei" Date: Fri, 29 May 2026 08:25:12 +0000 Subject: [PATCH 1/3] [Feat] support VAE parallel for Bagel Signed-off-by: siyuan.lei --- examples/offline_inference/bagel/README.md | 95 +++++++ examples/online_serving/bagel/README.md | 63 +++++ .../online_serving/test_bagel_expansion.py | 20 +- .../diffusion/models/bagel/autoencoder.py | 236 ++++++++++++++++++ .../diffusion/models/bagel/pipeline_bagel.py | 4 +- 5 files changed, 415 insertions(+), 3 deletions(-) diff --git a/examples/offline_inference/bagel/README.md b/examples/offline_inference/bagel/README.md index 9c2e31ff39f..f2685d56fd3 100644 --- a/examples/offline_inference/bagel/README.md +++ b/examples/offline_inference/bagel/README.md @@ -195,6 +195,101 @@ stages: devices: "0,1" ``` +### VAE Patch Parallelism + +[VAE Patch Parallelism](https://docs.vllm.ai/projects/vllm-omni/en/latest/user_guide/diffusion/parallelism/vae_patch_parallel.html) splits Bagel VAE **decode/encode** tiles across multiple GPUs on the **DiT stage**, reducing **per-GPU peak memory during VAE decode**. Use it when high-resolution `text2img` or `img2img` hits VAE OOM or large decode spikes. + +**Bagel-specific notes:** + +- Implemented in `BagelPipeline` via `DistributedAutoEncoder` (DiT stage only). +- **Single-stage** is the simplest path: one DiT process with TP + VAE patch parallel. +- **Two-stage**: enable on **stage 1 (DiT)** only; stage 0 (Thinker) keeps encoder-only `VAEEncoder` and does not use VAE patch parallel. +- You need a DiT `world_size` ≥ `vae_patch_parallel_size` (typically `tensor_parallel_size=2` on that stage). VAE PP reuses the DiT process group; it is not a standalone second-GPU VAE worker. + +**Single-stage via deploy YAML** (recommended for `end2end.py`): + +```yaml +pipeline: bagel_single_stage +async_chunk: false + +stages: + - stage_id: 0 + max_num_batched_tokens: 32768 + max_num_seqs: 1 + enforce_eager: true + trust_remote_code: true + enable_prefix_caching: false + devices: "0,1" + vae_use_tiling: true + parallel_config: + tensor_parallel_size: 2 + vae_patch_parallel_size: 2 + default_sampling_params: + seed: 52 +``` + +```bash +cd examples/offline_inference/bagel + +CUDA_VISIBLE_DEVICES=0,1 python end2end.py \ + --model /path/to/BAGEL-7B-MoT \ + --deploy-config /path/to/bagel_single_stage_vae_pp.yaml \ + --modality text2img \ + --prompts "A cute cat" \ + --steps 10 \ + --output ./out_vae_pp +``` + +**Single-stage via `Omni` kwargs** (same flags as online serving): + +```python +from vllm_omni.entrypoints.omni import Omni + +omni = Omni( + model="ByteDance-Seed/BAGEL-7B-MoT", + deploy_config="vllm_omni/deploy/bagel_single_stage.yaml", + tensor_parallel_size=2, + vae_patch_parallel_size=2, + vae_use_tiling=True, +) +# Then call omni.generate(...) as in end2end.py +``` + +**Two-stage (VAE PP on DiT only):** + +```yaml +stages: + - stage_id: 0 + devices: "0" + # AR Thinker — no vae_patch_parallel here + + - stage_id: 1 + devices: "0,1" + vae_use_tiling: true + parallel_config: + tensor_parallel_size: 2 + vae_patch_parallel_size: 2 +``` + +```bash +python end2end.py --model ByteDance-Seed/BAGEL-7B-MoT \ + --deploy-config /path/to/bagel_vae_pp.yaml \ + --modality text2img \ + --prompts "A cute cat" +``` + +**Startup log checks:** + +```text +INFO ... vae_patch_parallel_size=2 requires vae_use_tiling; automatically enabling it. +``` + +| Setting | Role | +| :------ | :--- | +| `parallel_config.tensor_parallel_size` | DiT world size / TP (must be ≥ `vae_patch_parallel_size`) | +| `parallel_config.vae_patch_parallel_size` | Number of ranks for distributed VAE tiles (`1` = off) | +| `vae_use_tiling` | Enable spatial tiling (auto-enabled when `vae_patch_parallel_size > 1`) | + #### Hybrid Sharded Data Parallel (HSDP) For larger Bagel deployments on multiple GPUs, you can enable HSDP (Hybrid Sharded Data Parallel) by modifying the stage configuration (for example, [`bagel.yaml`](../../../vllm_omni/deploy/bagel.yaml)). HSDP shards transformer weights across GPUs to reduce per-GPU memory usage. diff --git a/examples/online_serving/bagel/README.md b/examples/online_serving/bagel/README.md index ea8abe43f8e..b7a54da356c 100644 --- a/examples/online_serving/bagel/README.md +++ b/examples/online_serving/bagel/README.md @@ -68,6 +68,67 @@ vllm serve ByteDance-Seed/BAGEL-7B-MoT --omni --port 8091 --tensor-parallel-size Or set `tensor_parallel_size` per stage in a custom deploy YAML. +### VAE Patch Parallelism + +[VAE Patch Parallelism](https://docs.vllm.ai/projects/vllm-omni/en/latest/user_guide/diffusion/parallelism/vae_patch_parallel.html) distributes Bagel VAE **decode/encode** across multiple GPUs by splitting latent tiles. It lowers **per-GPU peak memory during VAE decode**, which helps high-resolution `text2img` / `img2img` when VAE becomes a bottleneck. + +**Scope for Bagel:** + +| Topology | VAE patch parallel | +| :------- | :----------------- | +| **Single-stage** (DiT only) | Supported on stage 0 (`BagelPipeline` + `DistributedAutoEncoder`) | +| **Two-stage** | Supported on **stage 1 (DiT)** only; stage 0 (Thinker) uses encoder-only VAE and is unrelated | + +**Requirements:** + +- `vae_patch_parallel_size > 1` and a distributed VAE (`DistributedAutoEncoder` on the DiT pipeline). +- The DiT process group must have at least `vae_patch_parallel_size` ranks. In practice this means the diffusion stage `world_size` must be ≥ 2 (commonly `tensor_parallel_size=2` on that stage). +- `vae_use_tiling` must be enabled. If you set `vae_patch_parallel_size > 1` and omit tiling, the registry auto-enables `vae_use_tiling` at startup. + +VAE patch parallel **reuses the DiT process group** (`dit_group`); it does not create a separate VAE-only worker pool. It is not a substitute for single-GPU VAE tiling (`vae_pp=1`). + +**Online serving (single-stage, 2 GPUs):** + +```bash +CUDA_VISIBLE_DEVICES=0,1 vllm serve ByteDance-Seed/BAGEL-7B-MoT --omni --port 8091 \ + --deploy-config vllm_omni/deploy/bagel_single_stage.yaml \ + --tensor-parallel-size 2 \ + --vae-patch-parallel-size 2 \ + --vae-use-tiling +``` + +**Online serving (two-stage, VAE PP on DiT stage 1):** use a custom deploy YAML, for example: + +```yaml +stages: + - stage_id: 0 + devices: "0" + # Thinker (AR) — no VAE patch parallel here + + - stage_id: 1 + devices: "0,1" + vae_use_tiling: true + parallel_config: + tensor_parallel_size: 2 + vae_patch_parallel_size: 2 +``` + +```bash +vllm serve ByteDance-Seed/BAGEL-7B-MoT --omni --port 8091 \ + --deploy-config /path/to/bagel_vae_pp.yaml +``` + +**Verify it is active** (check server logs at startup): + +```text +INFO ... vae_patch_parallel_size=2 requires vae_use_tiling; automatically enabling it. +``` + +| CLI flag | Default | Description | +| :------- | :------ | :---------- | +| `--vae-patch-parallel-size` | `1` | Number of DiT ranks used for VAE tile parallelism. Set to `2` or higher to enable. Should be ≤ DiT process group size (typically match `--tensor-parallel-size` on the diffusion stage). | +| `--vae-use-tiling` | off | Enable VAE spatial tiling. Required for VAE patch parallel (auto-enabled when `vae_patch_parallel_size > 1`). | + #### Hybrid Sharded Data Parallel (HSDP) For larger Bagel deployments on multiple GPUs, you can enable HSDP (Hybrid Sharded Data Parallel) by modifying the stage configuration (for example, [`bagel.yaml`](../../../vllm_omni/deploy/bagel.yaml)). HSDP shards transformer weights across GPUs to reduce per-GPU memory usage. @@ -323,6 +384,8 @@ python openai_chat_client.py \ | `stages[].gpu_memory_utilization` | per-stage | Fraction of GPU memory to use | | `stages[].enforce_eager` | per-stage | Disable CUDA graphs | | `stages[].tensor_parallel_size` | per-stage | TP degree for this stage | +| `stages[].parallel_config.vae_patch_parallel_size` | per-stage (DiT) | VAE tile parallelism degree (DiT stage only) | +| `stages[].vae_use_tiling` | per-stage (DiT) | Enable VAE tiling (required for VAE patch parallel) | | `connectors` | top-level | Define available connector instances (SHM, Mooncake) | | `platforms` | top-level | Platform-specific overrides (e.g. `xpu`) | diff --git a/tests/e2e/online_serving/test_bagel_expansion.py b/tests/e2e/online_serving/test_bagel_expansion.py index 2feaa3f079c..8ba1393b07c 100644 --- a/tests/e2e/online_serving/test_bagel_expansion.py +++ b/tests/e2e/online_serving/test_bagel_expansion.py @@ -43,7 +43,8 @@ def _get_diffusion_feature_cases(model: str): """Return L4 diffusion feature cases for Bagel. TeaCache, Cache-DiT, CFG-Parallel, - Ulysses-SP, Ring-Attention, Layerwise Offloading. + Ulysses-SP, Ring-Attention, Layerwise Offloading, + Hybrid Sharded Data Parallel, Tensor Parallelism, VAE Patch Parallelism. """ return [ @@ -135,6 +136,22 @@ def _get_diffusion_feature_cases(model: str): id="parallel_hsdp_2", marks=HSDP_2_FEATURE_MARKS, ), + # Tensor Parallelism (TP) + VAE Patch Parallelism (size=2) + pytest.param( + OmniServerParams( + model=model, + stage_config_path=BAGEL_PARALLEL_2_DEPLOY, + server_args=[ + "--tensor-parallel-size", + "2", + "--vae-patch-parallel-size", + "2", + "--vae-use-tiling", + ], + ), + id="tp_vae_patch_parallel_2", + marks=PARALLEL_2_FEATURE_MARKS, + ), ] @@ -157,6 +174,7 @@ def test_bagel( - Ring-Attention (degree=2) - Layerwise Offloading - Hybrid Sharded Data Parallel (size=2) + - Tensor Parallelism (TP) + VAE Patch Parallelism (size=2) Validation is delegated to assert_diffusion_response in tests/helpers/assertions.py, which checks output dimensions and basic correctness. diff --git a/vllm_omni/diffusion/models/bagel/autoencoder.py b/vllm_omni/diffusion/models/bagel/autoencoder.py index 0980f25cd19..9442fcaf101 100644 --- a/vllm_omni/diffusion/models/bagel/autoencoder.py +++ b/vllm_omni/diffusion/models/bagel/autoencoder.py @@ -14,6 +14,16 @@ import torch from einops import rearrange from torch import Tensor, nn +from vllm.logger import init_logger + +from vllm_omni.diffusion.distributed.autoencoders.distributed_vae_executor import ( + DistributedOperator, + DistributedVaeMixin, + GridSpec, + TileTask, +) + +logger = init_logger(__name__) @dataclass @@ -322,3 +332,229 @@ def decode(self, z: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor: return self.decode(self.encode(x)) + + +class DistributedAutoEncoder(AutoEncoder, DistributedVaeMixin): + def __init__(self, params: AutoEncoderParams): + super().__init__(params) + self.init_distributed() + + self.use_tiling = False + self.tile_sample_min_height = 512 + self.tile_sample_min_width = 512 + self.tile_sample_stride_height = 448 + self.tile_sample_stride_width = 448 + self.spatial_compression_ratio = params.downsample + + def blend_v(self, above: Tensor, current: Tensor, blend_extent: int) -> Tensor: + blend_extent = min(above.shape[-2], current.shape[-2], blend_extent) + if blend_extent <= 0: + return current + + for y in range(blend_extent): + alpha = y / blend_extent + current[:, :, y, :] = above[:, :, -blend_extent + y, :] * (1 - alpha) + current[:, :, y, :] * alpha + return current + + def blend_h(self, left: Tensor, current: Tensor, blend_extent: int) -> Tensor: + blend_extent = min(left.shape[-1], current.shape[-1], blend_extent) + if blend_extent <= 0: + return current + + for x in range(blend_extent): + alpha = x / blend_extent + current[:, :, :, x] = left[:, :, :, -blend_extent + x] * (1 - alpha) + current[:, :, :, x] * alpha + return current + + def decode_tile_split(self, z: Tensor) -> tuple[list[TileTask], GridSpec]: + _, _, latent_height, latent_width = z.shape + + sample_height = latent_height * self.spatial_compression_ratio + sample_width = latent_width * self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_height = self.tile_sample_min_height - self.tile_sample_stride_height + blend_width = self.tile_sample_min_width - self.tile_sample_stride_width + + tiletask_list = [] + for i in range(0, latent_height, tile_latent_stride_height): + for j in range(0, latent_width, tile_latent_stride_width): + tile = z[ + :, + :, + i : i + tile_latent_min_height, + j : j + tile_latent_min_width, + ] + tiletask_list.append( + TileTask( + tile_id=len(tiletask_list), + grid_coord=(i // tile_latent_stride_height, j // tile_latent_stride_width), + tensor=tile, + workload=tile.shape[-2] * tile.shape[-1], + ) + ) + grid_spec = GridSpec( + split_dims=(2, 3), + grid_shape=(tiletask_list[-1].grid_coord[0] + 1, tiletask_list[-1].grid_coord[1] + 1), + tile_spec={ + "sample_height": sample_height, + "sample_width": sample_width, + "blend_height": blend_height, + "blend_width": blend_width, + }, + output_dtype=z.dtype, + ) + return tiletask_list, grid_spec + + def decode_tile_exec(self, task: TileTask) -> Tensor: + tile = task.tensor + tile = tile / self.scale_factor + self.shift_factor + return self.decoder(tile) + + def decode_tile_merge(self, coord_tensor_map: dict[tuple[int, ...], Tensor], grid_spec: GridSpec) -> Tensor: + grid_h, grid_w = grid_spec.grid_shape + result_rows = [] + + for i in range(grid_h): + result_row = [] + for j in range(grid_w): + tile = coord_tensor_map[(i, j)] + if i > 0: + tile = self.blend_v(coord_tensor_map[(i - 1, j)], tile, grid_spec.tile_spec["blend_height"]) + if j > 0: + tile = self.blend_h(coord_tensor_map[(i, j - 1)], tile, grid_spec.tile_spec["blend_width"]) + + result_row.append( + tile[ + :, + :, + : self.tile_sample_stride_height, + : self.tile_sample_stride_width, + ] + ) + result_rows.append(torch.cat(result_row, dim=-1)) + + dec = torch.cat(result_rows, dim=-2)[ + :, + :, + : grid_spec.tile_spec["sample_height"], + : grid_spec.tile_spec["sample_width"], + ] + return dec + + def decode(self, z: Tensor) -> Tensor: + if not self.is_distributed_enabled(): + return super().decode(z) + + logger.debug("Bagel VAE decode running with distributed executor") + return self.distributed_executor.execute( + z, + DistributedOperator( + split=self.decode_tile_split, + exec=self.decode_tile_exec, + merge=self.decode_tile_merge, + ), + broadcast_result=True, + ) + + def encode_tile_split(self, x: Tensor) -> tuple[list[TileTask], GridSpec]: + _, _, height, width = x.shape + + latent_height = height // self.spatial_compression_ratio + latent_width = width // self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_height = tile_latent_min_height - tile_latent_stride_height + blend_width = tile_latent_min_width - tile_latent_stride_width + + tiletask_list = [] + for i in range(0, height, self.tile_sample_stride_height): + for j in range(0, width, self.tile_sample_stride_width): + tile = x[ + :, + :, + i : i + self.tile_sample_min_height, + j : j + self.tile_sample_min_width, + ] + tiletask_list.append( + TileTask( + tile_id=len(tiletask_list), + grid_coord=(i // self.tile_sample_stride_height, j // self.tile_sample_stride_width), + tensor=tile, + workload=tile.shape[-2] * tile.shape[-1], + ) + ) + grid_spec = GridSpec( + split_dims=(2, 3), + grid_shape=(tiletask_list[-1].grid_coord[0] + 1, tiletask_list[-1].grid_coord[1] + 1), + tile_spec={ + "latent_height": latent_height, + "latent_width": latent_width, + "blend_height": blend_height, + "blend_width": blend_width, + "tile_latent_stride_height": tile_latent_stride_height, + "tile_latent_stride_width": tile_latent_stride_width, + }, + output_dtype=x.dtype, + ) + return tiletask_list, grid_spec + + def encode_tile_exec(self, task: TileTask) -> Tensor: + tile = task.tensor + z = self.reg(self.encoder(tile)) + z = self.scale_factor * (z - self.shift_factor) + return z + + def encode_tile_merge(self, coord_tensor_map: dict[tuple[int, ...], Tensor], grid_spec: GridSpec) -> Tensor: + grid_h, grid_w = grid_spec.grid_shape + result_rows = [] + + for i in range(grid_h): + result_row = [] + for j in range(grid_w): + tile = coord_tensor_map[(i, j)] + if i > 0: + tile = self.blend_v(coord_tensor_map[(i - 1, j)], tile, grid_spec.tile_spec["blend_height"]) + if j > 0: + tile = self.blend_h(coord_tensor_map[(i, j - 1)], tile, grid_spec.tile_spec["blend_width"]) + + result_row.append( + tile[ + :, + :, + : grid_spec.tile_spec["tile_latent_stride_height"], + : grid_spec.tile_spec["tile_latent_stride_width"], + ] + ) + result_rows.append(torch.cat(result_row, dim=-1)) + + enc = torch.cat(result_rows, dim=-2)[ + :, + :, + : grid_spec.tile_spec["latent_height"], + : grid_spec.tile_spec["latent_width"], + ] + return enc + + def encode(self, x: Tensor) -> Tensor: + if not self.is_distributed_enabled(): + return super().encode(x) + logger.debug("Bagel VAE encode running with distributed executor") + return self.distributed_executor.execute( + x, + DistributedOperator( + split=self.encode_tile_split, + exec=self.encode_tile_exec, + merge=self.encode_tile_merge, + ), + broadcast_result=True, + ) diff --git a/vllm_omni/diffusion/models/bagel/pipeline_bagel.py b/vllm_omni/diffusion/models/bagel/pipeline_bagel.py index a62d2a75ad4..bad6a2f8a85 100644 --- a/vllm_omni/diffusion/models/bagel/pipeline_bagel.py +++ b/vllm_omni/diffusion/models/bagel/pipeline_bagel.py @@ -32,7 +32,7 @@ from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.model_executor.model_loader.weight_utils import download_weights_from_hf_specific -from .autoencoder import AutoEncoder, AutoEncoderParams +from .autoencoder import AutoEncoder, AutoEncoderParams, DistributedAutoEncoder from .bagel_transformer import Bagel, NaiveCache, Qwen2MoTConfig, Qwen2MoTForCausalLM logger = init_logger(__name__) @@ -253,7 +253,7 @@ def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = ""): ) self.transformer = self.language_model.model ae_params: AutoEncoderParams = default_ae_params() - self.vae = AutoEncoder(ae_params) + self.vae = DistributedAutoEncoder(ae_params) self.bagel = Bagel( language_model=self.language_model, From 3e09f80a27610d409ef6e25229d486c57c89b238 Mon Sep 17 00:00:00 2001 From: "siyuan.lei" Date: Fri, 29 May 2026 09:05:13 +0000 Subject: [PATCH 2/3] fix tests for Bagel VAE Patch Parallelism Signed-off-by: siyuan.lei --- tests/e2e/online_serving/test_bagel_expansion.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/e2e/online_serving/test_bagel_expansion.py b/tests/e2e/online_serving/test_bagel_expansion.py index 8ba1393b07c..bf195e9ce9b 100644 --- a/tests/e2e/online_serving/test_bagel_expansion.py +++ b/tests/e2e/online_serving/test_bagel_expansion.py @@ -9,6 +9,8 @@ - Ulysses-SP - Ring-Attention - Layerwise Offloading +- Hybrid Sharded Data Parallel +- Tensor Parallelism + VAE Patch Parallelism assert_diffusion_response validates successful generation and the expected 512x512 resolution. From 0c5a44abfefd4eef3b022803c6a2dcc636b5b9bc Mon Sep 17 00:00:00 2001 From: "siyuan.lei" Date: Tue, 2 Jun 2026 06:59:06 +0000 Subject: [PATCH 3/3] fix issue Signed-off-by: siyuan.lei --- .../online_serving/test_bagel_expansion.py | 24 ++++++++++++------- .../diffusion/models/bagel/autoencoder.py | 2 +- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/tests/e2e/online_serving/test_bagel_expansion.py b/tests/e2e/online_serving/test_bagel_expansion.py index bf195e9ce9b..a8f52c311d9 100644 --- a/tests/e2e/online_serving/test_bagel_expansion.py +++ b/tests/e2e/online_serving/test_bagel_expansion.py @@ -40,6 +40,21 @@ BAGEL_CI_DEPLOY, updates={"stages": {0: {"devices": "0"}, 1: {"devices": "0,1"}}}, ) +BAGEL_TP_VAE_PP_2_DEPLOY = modify_stage_config( + BAGEL_CI_DEPLOY, + updates={ + "stages": { + 0: {"devices": "0"}, + 1: { + "devices": "0,1", + "parallel_config": { + "tensor_parallel_size": 2, + "vae_patch_parallel_size": 2, + }, + }, + }, + }, +) def _get_diffusion_feature_cases(model: str): @@ -142,14 +157,7 @@ def _get_diffusion_feature_cases(model: str): pytest.param( OmniServerParams( model=model, - stage_config_path=BAGEL_PARALLEL_2_DEPLOY, - server_args=[ - "--tensor-parallel-size", - "2", - "--vae-patch-parallel-size", - "2", - "--vae-use-tiling", - ], + stage_config_path=BAGEL_TP_VAE_PP_2_DEPLOY, ), id="tp_vae_patch_parallel_2", marks=PARALLEL_2_FEATURE_MARKS, diff --git a/vllm_omni/diffusion/models/bagel/autoencoder.py b/vllm_omni/diffusion/models/bagel/autoencoder.py index 9442fcaf101..006e927b0f3 100644 --- a/vllm_omni/diffusion/models/bagel/autoencoder.py +++ b/vllm_omni/diffusion/models/bagel/autoencoder.py @@ -504,7 +504,7 @@ def encode_tile_split(self, x: Tensor) -> tuple[list[TileTask], GridSpec]: "tile_latent_stride_height": tile_latent_stride_height, "tile_latent_stride_width": tile_latent_stride_width, }, - output_dtype=x.dtype, + output_dtype=next(self.parameters()).dtype, ) return tiletask_list, grid_spec