[Feature][Bagel] Add CFG parallel mode#1578
Conversation
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: f22be8981d
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
|
@vllm-omni-reviewer |
1 similar comment
|
@vllm-omni-reviewer |
hsliuustc0106
left a comment
There was a problem hiding this comment.
Review Summary
This PR adds CFG parallel mode for Bagel with good structure and clean refactoring of _combine_cfg(). However, I found 2 issues that should be addressed:
Issues
1. Missing validation for cfg_parallel_size vs cfg_img_scale mismatch
If cfg_parallel_size=2 but cfg_img_scale > 1.0, image CFG is silently skipped (gathered only has 2 tensors, so cfg_img_v_t becomes None). User requested image CFG but does not get it.
2. Rank 2 crash when cfg_img_scale <= 1.0
The entry condition only checks use_cfg_text, not use_cfg_img. If cfg_parallel_size=3 but cfg_img_scale <= 1.0, rank 2 still executes and tries to use cfg_img_* parameters which are None.
Suggested Fix
Add validation at the start of _generate_image_parallel:
cfg_world_size = get_classifier_free_guidance_world_size()
if cfg_world_size >= 3 and not use_cfg_img:
raise ValueError("cfg_parallel_size=3 requires cfg_img_scale > 1.0")
if use_cfg_img and cfg_world_size < 3:
raise ValueError("Image CFG requires cfg_parallel_size >= 3")Minor
- CLI argument
--cfg-parallel-sizeonly affects img2img mode (might be intentional, but worth clarifying in help text)
Overall: Good feature implementation, just needs the validation guards for the configuration edge cases.
|
Thanks for your review! I just got free to fix it.. @hsliuustc0106 |
e8139fe to
2f5fe2e
Compare
|
@princepride @hsliuustc0106 PTAL!❤️😊 |
Add parallel CFG denoising path where 3 branches (gen, text_cfg, img_cfg) are distributed across GPUs via cfg_parallel infrastructure. - Extract _combine_cfg() for reusable CFG combination logic with renorm - Add _generate_image_parallel() for multi-GPU denoising loop - Support cfg_parallel_size=1 (batched), 2 (text CFG only), 3 (all branches) - Add validation guards for cfg_parallel_size vs cfg_img_scale consistency - Relax cfg_parallel_size validation in DiffusionParallelConfig to allow [1,2,3] Signed-off-by: Ding Zuhao <e1583181@u.nus.edu>
- Add 13 unit tests for Bagel._combine_cfg covering all renorm types, CFG scale combinations, and edge cases (CPU-only, no GPU required) - Fix dummy_run in diffusion_engine.py to set cfg_text_scale=1.0 and cfg_img_scale=1.0, preventing CFG parallel validation errors during warmup Signed-off-by: Ding Zuhao <e1583181@u.nus.edu>
…port - Add --cfg-parallel-size flag for selecting batched vs parallel mode - Add --negative-prompt support for text2img CFG - Always pass parallel_config to OmniDiffusion (cfg_parallel_size=1 default) Signed-off-by: Ding Zuhao <e1583181@u.nus.edu>
There was a problem hiding this comment.
@hsliuustc0106 @ZJY0516 I noticed that diffusion warm up happened in engine, but ar warm up happened in runner? Do we have plan to align it in the future?
Signed-off-by: 汪志鹏 <wangzhipeng628@gmail.com>
### vllm-omni-api - Source: [PR #1724](vllm-project/vllm-omni#1724) - Revert "[Profile] Adding metrics for Diffusion/DiT Single diffusion Pipeline (#668)" - Changes: - New feature: Revert "[Profile] Adding metrics for Diffusion/DiT Single diffusion Pipeline (#668)" ### vllm-omni-contrib - Source: [PR #1724](vllm-project/vllm-omni#1724) - Revert "[Profile] Adding metrics for Diffusion/DiT Single diffusion Pipeline (#668)" - Changes: - New feature: Revert "[Profile] Adding metrics for Diffusion/DiT Single diffusion Pipeline (#668)" ### vllm-omni-api - Source: [PR #1716](vllm-project/vllm-omni#1716) - [Feature]: Add vae-patch-parallel CLI argument in online serving - Changes: - New feature: [Feature]: Add vae-patch-parallel CLI argument in online serving ### vllm-omni-contrib - Source: [PR #1716](vllm-project/vllm-omni#1716) - [Feature]: Add vae-patch-parallel CLI argument in online serving - Changes: - New feature: [Feature]: Add vae-patch-parallel CLI argument in online serving ### vllm-omni-contrib - Source: [PR #1693](vllm-project/vllm-omni#1693) - [skip CI][Docs] Add TTS model developer guide - Changes: - New feature: [skip CI][Docs] Add TTS model developer guide ### vllm-omni-audio-tts - Source: [PR #1688](vllm-project/vllm-omni#1688) - [MiMo-Audio] Bugfix tp lg than 1 - Changes: - Bug fix: [MiMo-Audio] Bugfix tp lg than 1 ### vllm-omni-distributed - Source: [PR #1688](vllm-project/vllm-omni#1688) - [MiMo-Audio] Bugfix tp lg than 1 - Changes: - Bug fix: [MiMo-Audio] Bugfix tp lg than 1 ### vllm-omni-perf - Source: [PR #1688](vllm-project/vllm-omni#1688) - [MiMo-Audio] Bugfix tp lg than 1 - Changes: - Bug fix: [MiMo-Audio] Bugfix tp lg than 1 ### vllm-omni-perf - Source: [PR #1687](vllm-project/vllm-omni#1687) - [BugFix] Return proper HTTP status for ErrorResponse in create_speech - Changes: - Bug fix: [BugFix] Return proper HTTP status for ErrorResponse in create_speech ### vllm-omni-distributed - Source: [PR #1687](vllm-project/vllm-omni#1687) - [BugFix] Return proper HTTP status for ErrorResponse in create_speech - Changes: - Bug fix: [BugFix] Return proper HTTP status for ErrorResponse in create_speech ### vllm-omni-api - Source: [PR #1687](vllm-project/vllm-omni#1687) - [BugFix] Return proper HTTP status for ErrorResponse in create_speech - Changes: - Bug fix: [BugFix] Return proper HTTP status for ErrorResponse in create_speech - Additions: - `/v1/audio/speech` ### vllm-omni-quantization - Source: [PR #1687](vllm-project/vllm-omni#1687) - [BugFix] Return proper HTTP status for ErrorResponse in create_speech - Changes: - Bug fix: [BugFix] Return proper HTTP status for ErrorResponse in create_speech ### vllm-omni-cicd - Source: [PR #1683](vllm-project/vllm-omni#1683) - [CI] Remove high concurrency tests before issue #1374 fixed. - Changes: - Bug fix: [CI] Remove high concurrency tests before issue #1374 fixed. ### vllm-omni-audio-tts - Source: [PR #1678](vllm-project/vllm-omni#1678) - Add non-async chunk support for Qwen3-TTS - Changes: - New feature: Add non-async chunk support for Qwen3-TTS ### vllm-omni-cicd - Source: [PR #1678](vllm-project/vllm-omni#1678) - Add non-async chunk support for Qwen3-TTS - Changes: - New feature: Add non-async chunk support for Qwen3-TTS ### vllm-omni-cicd - Source: [PR #1677](vllm-project/vllm-omni#1677) - Replace hard-coded cuda generator with current_omni_platform.device_type ### vllm-omni-perf - Source: [PR #1677](vllm-project/vllm-omni#1677) - Replace hard-coded cuda generator with current_omni_platform.device_type ### vllm-omni-serving - Source: [PR #1675](vllm-project/vllm-omni#1675) - [Misc] remove logits_processor_pattern this field, because vllm have … ### vllm-omni-cicd - Source: [PR #1666](vllm-project/vllm-omni#1666) - [Cleanup] Move cosyvoice3 tests to model subdirectory ### vllm-omni-audio-tts - Source: [PR #1664](vllm-project/vllm-omni#1664) - [Bugfix] Fix all-silence TTS output: use float32 for speech tokenizer decoder - Changes: - Bug fix: [Bugfix] Fix all-silence TTS output: use float32 for speech tokenizer decoder ### vllm-omni-cicd - Source: [PR #1664](vllm-project/vllm-omni#1664) - [Bugfix] Fix all-silence TTS output: use float32 for speech tokenizer decoder - Changes: - Bug fix: [Bugfix] Fix all-silence TTS output: use float32 for speech tokenizer decoder ### vllm-omni-distributed - Source: [PR #1656](vllm-project/vllm-omni#1656) - [Optimize][Qwen3-Omni] Reduce inter-packet latency in async chunk ### vllm-omni-contrib - Source: [PR #1656](vllm-project/vllm-omni#1656) - [Optimize][Qwen3-Omni] Reduce inter-packet latency in async chunk ### vllm-omni-quantization - Source: [PR #1652](vllm-project/vllm-omni#1652) - [UX] Add progress bar for diffusion models - Changes: - New feature: [UX] Add progress bar for diffusion models ### vllm-omni-perf - Source: [PR #1652](vllm-project/vllm-omni#1652) - [UX] Add progress bar for diffusion models - Changes: - New feature: [UX] Add progress bar for diffusion models ### vllm-omni-distributed - Source: [PR #1651](vllm-project/vllm-omni#1651) - docs: Announce vllm-omni-skills community project ### vllm-omni-quantization - Source: [PR #1651](vllm-project/vllm-omni#1651) - docs: Announce vllm-omni-skills community project ### vllm-omni-perf - Source: [PR #1651](vllm-project/vllm-omni#1651) - docs: Announce vllm-omni-skills community project ### vllm-omni-contrib - Source: [PR #1649](vllm-project/vllm-omni#1649) - [Misc] update wechat ### vllm-omni-perf - Source: [PR #1642](vllm-project/vllm-omni#1642) - [chore] add _repeated_blocks for regional compilation support - Changes: - New feature: [chore] add _repeated_blocks for regional compilation support ### vllm-omni-api - Source: [PR #1641](vllm-project/vllm-omni#1641) - [Bugfix] Add TTS request validation to prevent engine crashes - Changes: - New feature: [Bugfix] Add TTS request validation to prevent engine crashes ### vllm-omni-cicd - Source: [PR #1641](vllm-project/vllm-omni#1641) - [Bugfix] Add TTS request validation to prevent engine crashes - Changes: - New feature: [Bugfix] Add TTS request validation to prevent engine crashes ### vllm-omni-image-gen - Source: [PR #1640](vllm-project/vllm-omni#1640) - [FP8 Quantization] Add FP8 quantization support for Flux transformer - Changes: - New feature: [FP8 Quantization] Add FP8 quantization support for Flux transformer - Additions: - text-to-image - Text-to-Image - Flux ### vllm-omni-quantization - Source: [PR #1640](vllm-project/vllm-omni#1640) - [FP8 Quantization] Add FP8 quantization support for Flux transformer - Changes: - New feature: [FP8 Quantization] Add FP8 quantization support for Flux transformer - Additions: - FP8 support or improvements ### vllm-omni-contrib - Source: [PR #1640](vllm-project/vllm-omni#1640) - [FP8 Quantization] Add FP8 quantization support for Flux transformer - Changes: - New feature: [FP8 Quantization] Add FP8 quantization support for Flux transformer ### vllm-omni-perf - Source: [PR #1640](vllm-project/vllm-omni#1640) - [FP8 Quantization] Add FP8 quantization support for Flux transformer - Changes: - New feature: [FP8 Quantization] Add FP8 quantization support for Flux transformer ### vllm-omni-contrib - Source: [PR #1631](vllm-project/vllm-omni#1631) - [BugFix] Fix LongCat Sequence Parallelism / Small Cleanup - Changes: - Bug fix: [BugFix] Fix LongCat Sequence Parallelism / Small Cleanup ### vllm-omni-cicd - Source: [PR #1628](vllm-project/vllm-omni#1628) - [Test][Qwen3-Omni]Modify Qwen3-Omni benchmark test cases ### vllm-omni-perf - Source: [PR #1628](vllm-project/vllm-omni#1628) - [Test][Qwen3-Omni]Modify Qwen3-Omni benchmark test cases ### vllm-omni-perf - Source: [PR #1619](vllm-project/vllm-omni#1619) - [Bugfix] Fix Qwen3-TTS code predictor crash due to missing vLLM config context - Changes: - Bug fix: [Bugfix] Fix Qwen3-TTS code predictor crash due to missing vLLM config context ### vllm-omni-perf - Source: [PR #1617](vllm-project/vllm-omni#1617) - [Refactor][Perf] Qwen3-TTS: re-prefill Code Predictor with torch.compile + enable Code2Wav decoder CUDA Graph - Changes: - Performance improvement: [Refactor][Perf] Qwen3-TTS: re-prefill Code Predictor with torch.compile + enable Code2Wav decoder CUDA Graph ### vllm-omni-contrib - Source: [PR #1615](vllm-project/vllm-omni#1615) - [Doc] Fix links in the configuration doc - Changes: - Bug fix: [Doc] Fix links in the configuration doc ### vllm-omni-audio-tts - Source: [PR #1614](vllm-project/vllm-omni#1614) - perf: replace per-element .item() GPU syncs with batch .tolist() in TTS code predictor - Changes: - Performance improvement: perf: replace per-element .item() GPU syncs with batch .tolist() in TTS code predictor ### vllm-omni-perf - Source: [PR #1614](vllm-project/vllm-omni#1614) - perf: replace per-element .item() GPU syncs with batch .tolist() in TTS code predictor - Changes: - Performance improvement: perf: replace per-element .item() GPU syncs with batch .tolist() in TTS code predictor ### vllm-omni-image-gen - Source: [PR #1609](vllm-project/vllm-omni#1609) - [Bugfix] Fix filepath resolution for model with subdir and GLM-Image generation - Changes: - Bug fix: [Bugfix] Fix filepath resolution for model with subdir and GLM-Image generation - Additions: - GLM-Image - GLM-Image - GLM-Image - GLM-Image - GLM-Image - GLM-Image - GLM-Image - GLM-Image ### vllm-omni-api - Source: [PR #1609](vllm-project/vllm-omni#1609) - [Bugfix] Fix filepath resolution for model with subdir and GLM-Image generation - Changes: - Bug fix: [Bugfix] Fix filepath resolution for model with subdir and GLM-Image generation ### vllm-omni-perf - Source: [PR #1609](vllm-project/vllm-omni#1609) - [Bugfix] Fix filepath resolution for model with subdir and GLM-Image generation - Changes: - Bug fix: [Bugfix] Fix filepath resolution for model with subdir and GLM-Image generation ### vllm-omni-contrib - Source: [PR #1604](vllm-project/vllm-omni#1604) - [Model]: support Helios from ByteDance ### vllm-omni-perf - Source: [PR #1604](vllm-project/vllm-omni#1604) - [Model]: support Helios from ByteDance ### vllm-omni-serving - Source: [PR #1602](vllm-project/vllm-omni#1602) - [Bugfix] fix kernel error for qwen3-omni - Changes: - Bug fix: [Bugfix] fix kernel error for qwen3-omni ### vllm-omni-distributed - Source: [PR #1598](vllm-project/vllm-omni#1598) - [BugFix] Fix load_weights error when loading HunyuanImage3.0 - Changes: - Bug fix: [BugFix] Fix load_weights error when loading HunyuanImage3.0 ### vllm-omni-image-gen - Source: [PR #1598](vllm-project/vllm-omni#1598) - [BugFix] Fix load_weights error when loading HunyuanImage3.0 - Changes: - Bug fix: [BugFix] Fix load_weights error when loading HunyuanImage3.0 - Additions: - HunyuanImage3 - HunyuanImage3Pipeline - HunyuanImage3 - HunyuanImage-3 - HunyuanImage-3 - HunyuanImage-3 - HunyuanImage3Pipeline - HunyuanImage3Pipeline - HunyuanImage3Pipeline - HunyuanImage3Pipeline - HunyuanImage3Pipeline - HunyuanImage3Pipeline - HunyuanImage3Pipeline - HunyuanImage3Pipeline - HunyuanImage-3 ### vllm-omni-quantization - Source: [PR #1598](vllm-project/vllm-omni#1598) - [BugFix] Fix load_weights error when loading HunyuanImage3.0 - Changes: - Bug fix: [BugFix] Fix load_weights error when loading HunyuanImage3.0 ### vllm-omni-perf - Source: [PR #1598](vllm-project/vllm-omni#1598) - [BugFix] Fix load_weights error when loading HunyuanImage3.0 - Changes: - Bug fix: [BugFix] Fix load_weights error when loading HunyuanImage3.0 ### vllm-omni-audio-tts - Source: [PR #1583](vllm-project/vllm-omni#1583) - [Feat][Qwen3TTS] reduce TTFA with flexible initial phase - Changes: - New feature: [Feat][Qwen3TTS] reduce TTFA with flexible initial phase ### vllm-omni-api - Source: [PR #1583](vllm-project/vllm-omni#1583) - [Feat][Qwen3TTS] reduce TTFA with flexible initial phase - Changes: - New feature: [Feat][Qwen3TTS] reduce TTFA with flexible initial phase ### vllm-omni-cicd - Source: [PR #1583](vllm-project/vllm-omni#1583) - [Feat][Qwen3TTS] reduce TTFA with flexible initial phase - Changes: - New feature: [Feat][Qwen3TTS] reduce TTFA with flexible initial phase ### vllm-omni-contrib - Source: [PR #1583](vllm-project/vllm-omni#1583) - [Feat][Qwen3TTS] reduce TTFA with flexible initial phase - Changes: - New feature: [Feat][Qwen3TTS] reduce TTFA with flexible initial phase ### vllm-omni-api - Source: [PR #1579](vllm-project/vllm-omni#1579) - [1/N][Refactor] Clean up dead code in output processor ### vllm-omni-serving - Source: [PR #1579](vllm-project/vllm-omni#1579) - [1/N][Refactor] Clean up dead code in output processor ### vllm-omni-distributed - Source: [PR #1578](vllm-project/vllm-omni#1578) - [Feature][Bagel] Add CFG parallel mode - Changes: - New feature: [Feature][Bagel] Add CFG parallel mode ### vllm-omni-cicd - Source: [PR #1578](vllm-project/vllm-omni#1578) - [Feature][Bagel] Add CFG parallel mode - Changes: - New feature: [Feature][Bagel] Add CFG parallel mode ### vllm-omni-perf - Source: [PR #1578](vllm-project/vllm-omni#1578) - [Feature][Bagel] Add CFG parallel mode - Changes: - New feature: [Feature][Bagel] Add CFG parallel mode ### vllm-omni-contrib - Source: [PR #1576](vllm-project/vllm-omni#1576) - 0.16.0 release ### vllm-omni-audio-tts - Source: [PR #1570](vllm-project/vllm-omni#1570) - [bugfix] Fix unexpected argument 'is_finished' in function llm2code2wav_async_chunk of mimo-audio - Changes: - Bug fix: [bugfix] Fix unexpected argument 'is_finished' in function llm2code2wav_async_chunk of mimo-audio ### vllm-omni-api - Source: [PR #1566](vllm-project/vllm-omni#1566) - [Bugfix] Import InputPreprocessor into Renderer - Changes: - Bug fix: [Bugfix] Import InputPreprocessor into Renderer ### vllm-omni-distributed - Source: [PR #1539](vllm-project/vllm-omni#1539) - [Debug] Enable curl retry aligned with openai ### vllm-omni-quantization - Source: [PR #1539](vllm-project/vllm-omni#1539) - [Debug] Enable curl retry aligned with openai ### vllm-omni-perf - Source: [PR #1539](vllm-project/vllm-omni#1539) - [Debug] Enable curl retry aligned with openai ### vllm-omni-image-gen - Source: [PR #1537](vllm-project/vllm-omni#1537) - [NPU] [Features] [Bugfix] Support mindiesd adaln - Changes: - New feature: [NPU] [Features] [Bugfix] Support mindiesd adaln - Additions: - mindiesd - mindiesd - Qwen-Image-Edit-2509 - mindiesd - mindiesd - mindiesd - mindiesd ### vllm-omni-perf - Source: [PR #1537](vllm-project/vllm-omni#1537) - [NPU] [Features] [Bugfix] Support mindiesd adaln - Changes: - New feature: [NPU] [Features] [Bugfix] Support mindiesd adaln ### vllm-omni-serving - Source: [PR #1536](vllm-project/vllm-omni#1536) - [Bugfix] Fix transformers 5.x compat issues in online TTS serving - Changes: - Bug fix: [Bugfix] Fix transformers 5.x compat issues in online TTS serving ### vllm-omni-perf - Source: [PR #1536](vllm-project/vllm-omni#1536) - [Bugfix] Fix transformers 5.x compat issues in online TTS serving - Changes: - Bug fix: [Bugfix] Fix transformers 5.x compat issues in online TTS serving
Signed-off-by: Ding Zuhao <e1583181@u.nus.edu> Signed-off-by: 汪志鹏 <wangzhipeng628@gmail.com> Co-authored-by: 汪志鹏 <wangzhipeng628@gmail.com> Signed-off-by: lishunyang <lishunyang12@163.com>
Description
all_gather+broadcast, achieving 2.45x speedup on CFG-active denoising steps compared to the single-GPU batched modecfg_parallel_sizeinDiffusionParallelConfig:1= batched (single GPU, existing behavior unchanged),3= parallel (3 GPUs)cfg-parallel-sizeCLI argument in the Bagel offline inference example end2end.pyChanges
Core:
bagel_transformer.py(+342/-32)_combine_cfg()as a reusable static method supporting 3 renorm modes (global,channel,text_channel)_generate_image_parallel()for multi-GPU denoising loop with rank-based branch computation_forward_flow_single_branch()for simplified single-branch forward passesget_classifier_free_guidance_world_size() > 1cfg_parallel_sizevscfg_img_scaleconsistencyConfig:
data.pycfg_parallel_sizevalidation to accept[1, 2, 3]Bug fix:
diffusion_engine.pycfg_text_scale=1.0andcfg_img_scale=1.0indummy_runto prevent CFG parallel validation errors during engine warmup.BTW: If other models use cfg_text_scale / cfg_img_scale in the future with different semantics (e.g., 1.0 does not mean "disabled"), there may be problems. However due to the mathematical principles of cfg, 1.0 always mean "disabled"
Tests:
test_combine_cfg.py(+314)Bagel._combine_cfgcovering all renorm types, CFG scale combinations, and edge cases (CPU-only, no GPU required)Example:
end2end.py--cfg-parallel-sizeCLI argumentsparallel_configtoOmniDiffusionPurpose
Add CFG parallel mode to accelerate the bagel inference
Test plan
python -m pytest tests/diffusion/models/bagel/test_combine_cfg.py -v(13/13 pass)cfg_parallel_size=1produces same output as upstreamcfg_parallel_size=3produces identical images to batched mode(stage 1 only)cfg_parallel_size=2produces identical images to batched mode(stage 1 only)Test Env
cuda13.1.1, A6000(48GB) * 3, stage 1 only gpu_memory_utilization = 0.85
Test Command
python end2end.py --modality img2img --image-path puppy.jpg --prompts "a dog wearing sunglasses" --cfg-text-scale 3.0 --cfg-img-scale 1.2 --steps 60 --cfg-parallel-size 1python end2end.py --modality img2img --image-path puppy.jpg --prompts "a dog wearing sunglasses" --cfg-text-scale 3.0 --cfg-img-scale 1.2 --steps 60 --cfg-parallel-size 3python end2end.py --modality text2img --prompts "a dog wearing sunglasses" --cfg-text-scale 3.0 --steps 60 --cfg-parallel-size 2python end2end.py --modality text2img --prompts "a dog wearing sunglasses" --cfg-text-scale 3.0 --steps 60 --cfg-parallel-size 3Test Result
Image quality
Image quality is the same.

Inference speed
Summary
Speedup: ~2.5x on CFG steps, ~0.4s/step.
text2img
Raise Value Error
Unit Test
Essential Elements of an Effective PR Description Checklist
supported_models.mdandexamplesfor a new model. Please runmkdocs serveto sync the documentation editions to./docs.BEFORE SUBMITTING, PLEASE READ https://github.com/vllm-project/vllm-omni/blob/main/CONTRIBUTING.md (anything written below this line will be removed by GitHub Actions)