[Perf][Bagel] Avoid per-step device syncs in Bagel img2img#3987

Merged: Gaohan123 merged 3 commits into vllm-project:main from natureofnature:bagel-i2i-device-sync on Jun 2, 2026
@natureofnature natureofnature commented May 29, 2026

Purpose

[Screenshot, AR stage: Screenshot from 2026-05-29 05-23-04]
[Screenshot, DiT stage: Screenshot from 2026-05-29 05-24-10]

The Bagel img2img path forces implicit GPU→CPU synchronizations in two hot loops, leaving the GPU idle while the CPU stalls:
  • AR stage (model_executor/models/bagel/bagel.py): the MoT routing calls vae_mask.any() / ~vae_mask.any() (each is an implicit .item() device sync), and _adjust_positions_for_img2img indexes the CUDA positions tensor element by element in a Python loop (one sync per token).
  • DiT stage (diffusion/models/bagel/bagel_transformer.py): the denoise loops iterate for t in timesteps where timesteps is a CUDA tensor, so t is a 0-d tensor and both building the timestep tensor and the cfg_interval comparison sync on each step; attention also uses Python max()/sum() over length tensors (one sync per element).
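
To make the patterns above concrete, here is a minimal sketch of how such hidden synchronizations typically arise in PyTorch. It is not the Bagel code: the tensor values and the cfg_interval value are illustrative assumptions, and only the names vae_mask, timesteps, query_lens and cfg_interval come from the description. The snippet also runs on CPU, where the conversions are cheap; on CUDA, each host-side read of a tensor value blocks until the GPU has caught up.

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Hypothetical stand-ins for the tensors named in the description.
vae_mask = torch.tensor([False, True, True, False], device=device)
timesteps = torch.linspace(1.0, 0.0, steps=4, device=device)
query_lens = torch.tensor([3, 5, 2], device=device)

# Turning a tensor into a Python bool (explicitly, or via `if tensor:`)
# has to read the value back to the host.
has_vae = bool(vae_mask.any())

# Iterating a tensor yields 0-d tensors that still live on `device`.
steps = [t for t in timesteps]

# Each comparison result is again a tensor; using it as a Python condition
# reads it back, so this costs one host read per step.
cfg_interval = (0.4, 1.0)  # illustrative value, not taken from the PR
in_interval = [bool(t > cfg_interval[0]) for t in steps]

# Built-in max() walks the tensor element by element and converts every
# comparison to a Python bool. Built-in sum() issues one tiny add per
# element; the host read happens when the result is used as a number.
longest = max(query_lens)
total = sum(query_lens)
```

In a real run these reads are easy to miss because nothing in the code calls .item() explicitly; a profiler trace like the ones above is usually what exposes them.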

Fix

  • Cache the VAE-mask occupancy once per request as plain bools; the per-layer routing reads the bools instead of calling .any().
  • Copy positions to the host once before the boundary loop.
  • Iterate timesteps.tolist() so t is a Python float (the timesteps[i] tensor is still used for the scheduler).
  • Use tensor reductions (query_lens.max(), query_lens.sum()) instead of Python max()/sum() over tensors.
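
The four fixes can be sketched as follows. This is an illustration under stated assumptions rather than the merged code: the route helper, the boundary rule (a position that decreases) and all tensor values are hypothetical, chosen only to show where the single host transfer happens in each case.

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Hypothetical inputs; only the names mirror the PR description.
vae_mask = torch.tensor([False, True, True, False], device=device)
positions = torch.tensor([0, 1, 2, 0, 1], device=device)
timesteps = torch.linspace(1.0, 0.0, steps=4, device=device)
query_lens = torch.tensor([3, 5, 2], device=device)

# 1) Resolve mask occupancy once per request; layers then read plain bools.
has_vae = bool(vae_mask.any())
has_text = bool((~vae_mask).any())

def route(layer_idx: int) -> str:
    # Hypothetical per-layer routing decision: no tensor access in here.
    if has_vae and has_text:
        return "mixed"
    return "vae_only" if has_vae else "text_only"

routes = [route(i) for i in range(3)]

# 2) One device-to-host copy, then a pure-Python loop over the list.
positions_host = positions.tolist()
boundaries = [
    i for i in range(1, len(positions_host))
    if positions_host[i] < positions_host[i - 1]  # assumed boundary rule
]

# 3) Iterate Python floats; index the tensor only where a tensor is needed.
cfg_interval = (0.4, 1.0)  # illustrative value
use_cfg = []
for i, t in enumerate(timesteps.tolist()):
    use_cfg.append(t > cfg_interval[0] and t <= cfg_interval[1])
    t_tensor = timesteps[i]  # stays on device, e.g. for a scheduler call

# 4) Reduce on the device with a single kernel each.
max_len = query_lens.max()
total_len = query_lens.sum()
```

Note that query_lens.max() still returns a 0-d tensor; if a Python int is required downstream, converting it costs one read instead of one per element.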

Test Plan

  • Correctness: ran i2i before/after on the same input + seed at 256² and 1024²; generated PNGs are md5-identical between baseline and patched (changes are numerically identical, not approximations).
  • Performance: split serving (stage-0 AR / stage-1 DiT on separate GPUs, RDMA connector), 1 warmup + 5 measured i2i requests per config; compared metrics.stage_durations (stage_0_gen_ms / stage_1_gen_ms) baseline vs patched.
  • Smoke: service boots and serves t2i + i2i (HTTP 200) with the patched code.
  • Prompt: Change the color to blue
  • Input image: input_small
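
The correctness check in the plan above amounts to comparing file digests. A minimal sketch of such a comparison, using only the standard library; the helper names and the temporary files are assumptions made for illustration, not the script used for this PR:

```python
import hashlib
import tempfile
from pathlib import Path

def file_md5(path: Path, chunk_size: int = 1 << 20) -> str:
    """Return the hex MD5 digest of a file, read in chunks."""
    digest = hashlib.md5()
    with path.open("rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

def outputs_identical(baseline: Path, patched: Path) -> bool:
    """True when both files have the same digest, i.e. the same bytes."""
    return file_md5(baseline) == file_md5(patched)

# Demo with placeholder bytes standing in for generated PNGs.
with tempfile.TemporaryDirectory() as tmp:
    baseline = Path(tmp) / "baseline.png"
    patched = Path(tmp) / "patched.png"
    other = Path(tmp) / "other.png"
    baseline.write_bytes(b"same bytes")
    patched.write_bytes(b"same bytes")
    other.write_bytes(b"different bytes")
    same = outputs_identical(baseline, patched)
    differ = outputs_identical(baseline, other)
```

A digest match is a stricter criterion than a pixel tolerance, which fits a change that is meant to alter only where values are read, not how they are computed.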

Test Result

Split serving (RDMA connector, 15 steps, i2i); stage_0 = AR, stage_1 = DiT. All changes are numerically identical: outputs are byte-for-byte unchanged (md5 of generated images matches before/after at both 256² and 1024²).
Mean of 5 measured runs (H800, AR/DiT disaggregation):

| resolution | stage | before (ms) | after (ms) | delta |
|---|---|---|---|---|
| 256×256 | AR | 1064 | 890 | −16.4% |
| 256×256 | DiT | 755 | 724 | −4.1% |
| 1024×1024 | AR | 2265 | 1969 | −13.1% |
| 1024×1024 | DiT | 7714 | 7362 | −4.6% |
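
For reference, the delta column is the relative change (after − before) / before. A short sketch that recomputes it from the reported means; the function name is an assumption, the numbers are the ones in the table:

```python
def relative_delta_percent(before_ms: float, after_ms: float) -> float:
    """Relative change in percent; negative means the patched run is faster."""
    return (after_ms - before_ms) / before_ms * 100.0

# (resolution, stage, before_ms, after_ms) as reported above.
results = [
    ("256x256", "AR", 1064, 890),
    ("256x256", "DiT", 755, 724),
    ("1024x1024", "AR", 2265, 1969),
    ("1024x1024", "DiT", 7714, 7362),
]

deltas = [round(relative_delta_percent(b, a), 1) for _, _, b, a in results]

for (resolution, stage, before, after), delta in zip(results, deltas):
    print(f"{resolution:>9} {stage:<3} {before:>5} ms -> {after:>5} ms  {delta:+.1f}%")
```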

@natureofnature

@princepride PTAL

Cache the VAE-mask flags and iterate host-side values so the AR MoT routing
and the DiT denoise loop stop forcing GPU->CPU syncs on every layer/step.

Signed-off-by: natureofnature <wzliu@connect.hku.hk>
@natureofnature natureofnature force-pushed the bagel-i2i-device-sync branch from 1e68110 to 8ddca2e Compare May 29, 2026 09:52
@hsliuustc0106 hsliuustc0106 added the ready (label to trigger buildkite CI) and merge-test (label to trigger buildkite merge test CI) labels Jun 1, 2026
@Gaohan123 Gaohan123 added this to the v0.22.0 milestone Jun 1, 2026

@Gaohan123 Gaohan123 left a comment

LGTM. Thanks

@Gaohan123 Gaohan123 merged commit 1fb423e into vllm-project:main Jun 2, 2026
7 of 8 checks passed
86MaxCao pushed a commit to 86MaxCao/vllm-omni that referenced this pull request Jun 4, 2026
akshatvishu pushed a commit to akshatvishu/vllm-omni that referenced this pull request Jun 13, 2026
…ect#3987)

Signed-off-by: natureofnature <wzliu@connect.hku.hk>
Signed-off-by: akshatvishu <akshatnayak197@gmail.com>