Skip to content

Support VAE parallel for Bagel #3982

Merged
princepride merged 4 commits into
vllm-project:mainfrom
lsyyysky:main
Jun 2, 2026
Merged

Support VAE parallel for Bagel #3982
princepride merged 4 commits into
vllm-project:mainfrom
lsyyysky:main

Conversation

@lsyyysky

@lsyyysky lsyyysky commented May 29, 2026

Copy link
Copy Markdown
Contributor

Purpose

Add VAE Patch Parallelism support for the Bagel (BAGEL-7B-MoT) diffusion model.

This PR lets Bagel split the latent into spatial tiles and distribute them across the DiT process group, so each rank only materializes the activations for its own tiles instead of the whole image — lowering per-GPU peak memory at high resolution.

Key points:

  • Introduce DistributedAutoEncoder(AutoEncoder, DistributedVaeMixin) in vllm_omni/diffusion/models/bagel/autoencoder.py, implementing split / exec / merge for both decode and encode (with overlap blending to avoid seams).
  • BagelPipeline now instantiates DistributedAutoEncoder so the DiT stage can run distributed e READMEs (scope, requirements, deploy YAML / CLI examples, verification via startup logs).

Scope:

Topology VAE patch parallel
Single-stage (DiT only) Supported on stage 0 (BagelPipeline + DistributedAutoEncoder)
Two-stage Supported on stage 1 (DiT) only; stage 0 (Thinker) uses the encoder-only VAE and is unrelated

Test Plan

Hardware: 2× GPU, BAGEL-7B-MoT at /data/Bagel/BAGEL-7B-MoT.

  1. End-to-end (single-stage DiT, text2img, 1024×1024, 20 steps, seed=42) — compare tensor_parallel_size=2 only vs tensor_parallel_size=2 + vae_patch_parallel_size=2 (vae_use_tiling=true). Metrics: per-request peak GPU memory (Peak GPU memory (this request)) and generation latency (stage_0_gen_ms).

    Single-stage deploy used for both runs (only vae_patch_parallel_size / vae_use_tiling differ):

    pipeline: bagel_single_stage
    stages:
      - stage_id: 0
        enforce_eager: true
        trust_remote_code: true
        devices: "0,1"
        vae_use_tiling: true            # VAE PP run only
        parallel_config:
          tensor_parallel_size: 2
          vae_patch_parallel_size: 2    # 1 for the TP-only baseline
  2. Correctness: confirm Bagel VAE decode running with distributed executor appears in logs when enabled, and generated images are valid.

Test Result

End-to-end (1024×1024, 20 steps) — inference phase, rank0(A100)

Config TP VAE PP Latency (s) Peak reserved Peak allocated vs TP-only
TP=2 only 2 1 (off) 16.63 20.0 GB 16.88 GB baseline
TP=2 + VAE PP=2 2 2 (on) 16.63 16.9 GB 16.15 GB -15%

The benefit grows with resolution: negligible at 512×512 (Transformer-bound), ~3 GB at 1024×1024 end-to-end


@chatgpt-codex-connector

Copy link
Copy Markdown

Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits.
Credits must be used to enable repository wide code reviews.

lsyyysky added 2 commits May 29, 2026 09:23
Signed-off-by: siyuan.lei <siyuanlei37@gmail.com>
Signed-off-by: siyuan.lei <siyuanlei37@gmail.com>
id="parallel_hsdp_2",
marks=HSDP_2_FEATURE_MARKS,
),
# Tensor Parallelism (TP) + VAE Patch Parallelism (size=2)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new TP + VAE-PP setup is not stage-local in deploy-config mode, so --tensor-parallel-size 2 also leaks into stage 0 while that stage is still pinned to devices: "0"

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

"tile_latent_stride_height": tile_latent_stride_height,
"tile_latent_stride_width": tile_latent_stride_width,
},
output_dtype=x.dtype,

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The distributed encode() path uses x.dtype for gather/broadcast buffers, so Bagel img2img can encode latents under autocast and still repack them into float32 buffers unnecessarily

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

lsyyysky added 2 commits June 2, 2026 06:59
Signed-off-by: siyuan.lei <siyuanlei37@gmail.com>
@princepride princepride enabled auto-merge (squash) June 2, 2026 10:24
@princepride princepride added the ready label to trigger buildkite CI label Jun 2, 2026
@princepride princepride merged commit bd37f3c into vllm-project:main Jun 2, 2026
8 checks passed
86MaxCao pushed a commit to 86MaxCao/vllm-omni that referenced this pull request Jun 4, 2026
Signed-off-by: siyuan.lei <siyuanlei37@gmail.com>
akshatvishu pushed a commit to akshatvishu/vllm-omni that referenced this pull request Jun 13, 2026
Signed-off-by: siyuan.lei <siyuanlei37@gmail.com>
Signed-off-by: akshatvishu <akshatnayak197@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ready label to trigger buildkite CI

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants