[Feature] Add CFG parallel to Omnigen2 #2074
Signed-off-by: zhou zhuoxin <zhouzhuoxin1508@outlook.com>
Co-authored-by: princepride <wangzhipeng628@gmail.com> Co-authored-by: Ding Zuhao <e1583181@u.nus.edu> Signed-off-by: zhou zhuoxin <zhouzhuoxin1508@outlook.com>
Force-pushed from 23c4ec5 to 76cc9c7
) -> torch.Tensor:
    """CFG parallel denoising loop: each rank computes one CFG branch, returns latents.

    Rank 0: cond branch (prompt_embeds, ref_latents)
lishunyang12 left a comment:
Left a few comments on the parallel path.
)
elif text_guidance_scale > 1.0:
model_pred_uncond = self.predict(
gathered = cfg_group.all_gather(local_pred, separate_tensors=True)
all_gather_into_tensor requires contiguous input. local_pred coming out of self.predict() may not be contiguous depending on the transformer output layout. Add local_pred = local_pred.contiguous() before the all-gather, same way you already do for latents at line 1291.
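To illustrate what "contiguous" means here, the following is a minimal pure-Python sketch of a row-major stride check. It is not PyTorch's implementation; `row_major_strides` and `is_contiguous` are hypothetical helpers written only for this illustration, and the shapes and strides are made-up examples.

```python
def row_major_strides(shape):
    """Strides (in elements) that a densely packed row-major tensor would have."""
    strides = []
    acc = 1
    for dim in reversed(shape):
        strides.append(acc)
        acc *= dim
    return tuple(reversed(strides))


def is_contiguous(shape, strides):
    """True if the strides match the dense row-major layout.

    Dimensions of size 1 are ignored, since their stride never affects addressing.
    """
    expected = row_major_strides(shape)
    return all(d == 1 or s == e for d, s, e in zip(shape, strides, expected))


# A freshly allocated (2, 3) tensor is dense: strides (3, 1).
print(is_contiguous((2, 3), (3, 1)))

# Transposing it yields a (3, 2) view over the same memory with strides (1, 3),
# which no longer matches the dense layout (2, 1).
print(is_contiguous((3, 2), (1, 3)))
```

A view like the second one is the kind of output layout the comment is worried about; calling `.contiguous()` on such a tensor copies it into the dense layout before the collective runs.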
for i, t in enumerate(timesteps):
    in_cfg_range = self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1]
    use_cfg_this_step = in_cfg_range and self.text_guidance_scale > 1.0
self.text_guidance_scale > 1.0 is always true inside _processing_parallel (the caller checks it in cfg_parallel_ready). This makes use_cfg_this_step equivalent to just in_cfg_range. Not a bug, but it's confusing — consider simplifying to use_cfg_this_step = in_cfg_range.
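As a rough sketch of the range check quoted above, the snippet below evaluates the same comparison for every step index. The helper name `steps_in_cfg_range` and the example values (4 steps, range `(0.0, 0.5)`, scale `5.0`) are assumptions chosen for illustration, not values from the PR.

```python
def steps_in_cfg_range(num_steps, cfg_range):
    """Per-step flag mirroring: cfg_range[0] <= i / len(timesteps) <= cfg_range[1]."""
    lo, hi = cfg_range
    return [lo <= i / num_steps <= hi for i in range(num_steps)]


num_steps = 4
cfg_range = (0.0, 0.5)      # hypothetical: apply CFG during the first half of sampling
text_guidance_scale = 5.0   # hypothetical: any value > 1.0

in_range = steps_in_cfg_range(num_steps, cfg_range)
print(in_range)  # [True, True, True, False]

# When the caller already guarantees text_guidance_scale > 1.0,
# the extra condition does not change any step's decision.
with_scale_check = [flag and text_guidance_scale > 1.0 for flag in in_range]
print(with_scale_check == in_range)  # True
```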
model_pred_uncond = self.predict(
gathered = cfg_group.all_gather(local_pred, separate_tensors=True)
model_pred, model_pred_uncond = gathered[0], gathered[1]
if use_cfg_img_this_step and len(gathered) > 2:
Nit: len(gathered) > 2 is always true when use_cfg_img_this_step is true, since the caller guarantees cfg_world_size >= 3 whenever use_cfg_img. The double-check reads like there's a case where it could be false. Drop it or add a comment explaining it's a defensive check.
model_pred = model_pred_uncond + self.text_guidance_scale * (model_pred - model_pred_uncond)
latents = self.scheduler.step(model_pred, t, latents, return_dict=False)[0]
else:
    # Outside CFG interval: all ranks use cond branch, no comm
Outside the CFG range every rank computes the same cond prediction independently — wasted FLOPs on ranks 1+. Consider having only rank 0 run predict and broadcasting model_pred, like you already do for the initial latents sync.
Thanks, but the data is the same on all ranks. If we ran the prediction only on rank 0, we'd just be adding an extra broadcast step, so it wouldn't really save any time.
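For reference, the text-guidance update shown in the diff above (`model_pred = model_pred_uncond + self.text_guidance_scale * (model_pred - model_pred_uncond)`) can be sketched on plain Python lists. This is only an illustration of the arithmetic; `apply_text_guidance` and the sample numbers are hypothetical, and the real code operates on tensors gathered from the CFG ranks.

```python
def apply_text_guidance(cond, uncond, scale):
    """Elementwise uncond + scale * (cond - uncond), as in the quoted diff line."""
    return [u + scale * (c - u) for c, u in zip(cond, uncond)]


cond = [1.0, 2.0]     # stand-in for the cond-branch prediction
uncond = [0.5, 1.0]   # stand-in for the uncond-branch prediction

print(apply_text_guidance(cond, uncond, 5.0))  # [3.0, 6.0]

# With scale == 1.0 the result is just the cond prediction, which is why
# the code only enables CFG when text_guidance_scale > 1.0.
print(apply_text_guidance(cond, uncond, 1.0))  # [1.0, 2.0]
```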
Signed-off-by: zhou zhuoxin <zhouzhuoxin1508@outlook.com>
Force-pushed from 1784777 to de46e33
Signed-off-by: zhou zhuoxin <zhouzhuoxin1508@outlook.com>
@princepride Could you please review this?
princepride left a comment:
I noticed that this model also needs cfg_p=3. We currently have a PR that refactors cfg_p: #2063. Could you coordinate with its author and try to use our own implementation of cfg_p?
Thanks, I'll look into it.
There are similar problems in integrating the existing cfg_p (even after #2063) with Omnigen2 and DreamID-Omni: cfg_p currently supports at most 2 branches, while Omnigen2 has 3 branches and DreamID-Omni has 4. Further refactoring is needed to support diffusion models with more than 2 branches. It would be good to discuss a more general CFG parallel API, and a plan to refactor DreamID-Omni and Omnigen2 onto it.
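As one possible starting point for that discussion, the sketch below shows a round-robin mapping from an arbitrary number of CFG branches to ranks. Everything in it is hypothetical: `assign_branches`, the branch labels, and the round-robin policy are assumptions for illustration and do not reflect the existing cfg_p implementation or #2063.

```python
def assign_branches(branch_names, cfg_world_size):
    """Map each CFG branch to a rank, round-robin (hypothetical policy)."""
    if cfg_world_size < 1:
        raise ValueError("cfg_world_size must be at least 1")
    assignment = {rank: [] for rank in range(cfg_world_size)}
    for index, name in enumerate(branch_names):
        assignment[index % cfg_world_size].append(name)
    return assignment


# Three branches on three ranks: one branch per rank.
print(assign_branches(["b0", "b1", "b2"], 3))
# {0: ['b0'], 1: ['b1'], 2: ['b2']}

# Four branches on two ranks: each rank evaluates two branches sequentially.
print(assign_branches(["b0", "b1", "b2", "b3"], 2))
# {0: ['b0', 'b2'], 1: ['b1', 'b3']}
```

A mapping of this shape would let the number of branches and the CFG group size vary independently, at the cost of some ranks running more than one forward pass per step when there are fewer ranks than branches.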
A recent PR changed the structure of the diffusion features docs. Please take a look at #1928.
Signed-off-by: zhou zhuoxin <zhouzhuoxin1508@outlook.com>
Missing e2e test for CFG parallelism. Please add a test that covers Documentation incomplete:
Can you also report the peak VRAM usage in your PR body?
Thanks for the review!
Done in #2423.


Purpose
cfg_parallel_size=2 for text-to-image, cfg_parallel_size=3 for image editing with ref image)
cfg_world_size < 3)
Test Plan
text2image
python text_to_image.py \
  --model "OmniGen2/OmniGen2" \
  --prompt "A classroom with bright lighting and wooden desks." \
  --negative-prompt "(((deformed))), blurry, over saturation, bad anatomy, disfigured, poorly drawn face, mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), fused fingers, messy drawing, broken legs censor, censored, censor_bar" \
  --num-inference-steps 50 \
  --seed 0 \
  --guidance-scale 5.0 \
  --cfg-parallel-size 2 \
  --output /workspace/outputs/image_t2icfg3.png
image2image
python image_edit.py \
  --image /workspace/image1.png \
  --model "OmniGen2/OmniGen2" \
  --prompt "Change the background to classroom." \
  --negative-prompt "(((deformed))), blurry, over saturation, bad anatomy, disfigured, poorly drawn face, mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), fused fingers, messy drawing, broken legs censor, censored, censor_bar" \
  --num-inference-steps 50 \
  --seed 0 \
  --guidance-scale 5.0 \
  --guidance-scale-2 2.0 \
  --cfg-parallel-size 3 \
  --output /workspace/outputs/image_edit.png
Test Result
Text-to-Image (1024×1024)
Output images are visually identical across sequential and parallel modes.
Image Editing (input size: 1696×2528)
Output images are visually identical across sequential and parallel modes.