-
Notifications
You must be signed in to change notification settings - Fork 1k
[Feature] Add CFG parallel to Omnigen2 #2074
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
63e520e
76cc9c7
de46e33
ab30b94
a45ead1
9581777
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -29,6 +29,11 @@ | |
| from vllm.model_executor.models.utils import AutoWeightsLoader | ||
|
|
||
| from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig | ||
| from vllm_omni.diffusion.distributed.parallel_state import ( | ||
| get_cfg_group, | ||
| get_classifier_free_guidance_rank, | ||
| get_classifier_free_guidance_world_size, | ||
| ) | ||
| from vllm_omni.diffusion.distributed.utils import get_local_device | ||
| from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader | ||
| from vllm_omni.diffusion.models.omnigen2.omnigen2_transformer import ( | ||
|
|
@@ -1170,72 +1175,178 @@ def processing( | |
| ) | ||
| self._num_timesteps = len(timesteps) | ||
|
|
||
| for i, t in enumerate(timesteps): | ||
| model_pred = self.predict( | ||
| t=t, | ||
| cfg_world_size = get_classifier_free_guidance_world_size() | ||
| use_cfg_img = self.image_guidance_scale > 1.0 | ||
| cfg_parallel_ready = ( | ||
| self.text_guidance_scale > 1.0 | ||
| and cfg_world_size > 1 | ||
| # image guidance needs a 3rd rank for the ref branch; fall back to serial if not available | ||
| and (not use_cfg_img or cfg_world_size >= 3) | ||
| ) | ||
| if cfg_parallel_ready: | ||
| latents = self._processing_parallel( | ||
| latents=latents, | ||
| prompt_embeds=prompt_embeds, | ||
| freqs_cis=freqs_cis, | ||
| prompt_embeds=prompt_embeds, | ||
| prompt_attention_mask=prompt_attention_mask, | ||
| ref_image_hidden_states=ref_latents, | ||
| ) | ||
| text_guidance_scale = ( | ||
| self.text_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0 | ||
| ref_latents=ref_latents, | ||
| negative_prompt_embeds=negative_prompt_embeds, | ||
| negative_prompt_attention_mask=negative_prompt_attention_mask, | ||
| timesteps=timesteps, | ||
| dtype=dtype, | ||
| step_func=step_func, | ||
| ) | ||
| image_guidance_scale = ( | ||
| self.image_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0 | ||
| ) | ||
|
|
||
| if text_guidance_scale > 1.0 and image_guidance_scale > 1.0: | ||
| model_pred_ref = self.predict( | ||
| else: | ||
| for i, t in enumerate(timesteps): | ||
| model_pred = self.predict( | ||
| t=t, | ||
| latents=latents, | ||
| prompt_embeds=negative_prompt_embeds, | ||
| prompt_embeds=prompt_embeds, | ||
| freqs_cis=freqs_cis, | ||
| prompt_attention_mask=negative_prompt_attention_mask, | ||
| prompt_attention_mask=prompt_attention_mask, | ||
| ref_image_hidden_states=ref_latents, | ||
| ) | ||
| text_guidance_scale = ( | ||
| self.text_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0 | ||
| ) | ||
| image_guidance_scale = ( | ||
| self.image_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0 | ||
| ) | ||
|
|
||
| if text_guidance_scale > 1.0 and image_guidance_scale > 1.0: | ||
| model_pred_ref = self.predict( | ||
| t=t, | ||
| latents=latents, | ||
| prompt_embeds=negative_prompt_embeds, | ||
| freqs_cis=freqs_cis, | ||
| prompt_attention_mask=negative_prompt_attention_mask, | ||
| ref_image_hidden_states=ref_latents, | ||
| ) | ||
|
|
||
| model_pred_uncond = self.predict( | ||
| t=t, | ||
| latents=latents, | ||
| prompt_embeds=negative_prompt_embeds, | ||
| freqs_cis=freqs_cis, | ||
| prompt_attention_mask=negative_prompt_attention_mask, | ||
| ref_image_hidden_states=None, | ||
| ) | ||
|
|
||
| model_pred = ( | ||
| model_pred_uncond | ||
| + image_guidance_scale * (model_pred_ref - model_pred_uncond) | ||
| + text_guidance_scale * (model_pred - model_pred_ref) | ||
| ) | ||
| elif text_guidance_scale > 1.0: | ||
| model_pred_uncond = self.predict( | ||
| t=t, | ||
| latents=latents, | ||
| prompt_embeds=negative_prompt_embeds, | ||
| freqs_cis=freqs_cis, | ||
| prompt_attention_mask=negative_prompt_attention_mask, | ||
| ref_image_hidden_states=None, | ||
| ) | ||
| model_pred = model_pred_uncond + text_guidance_scale * (model_pred - model_pred_uncond) | ||
|
|
||
| latents = self.scheduler.step(model_pred, t, latents, return_dict=False)[0] | ||
|
|
||
| latents = latents.to(dtype=dtype) | ||
|
|
||
| if step_func is not None: | ||
| step_func(i, self._num_timesteps) | ||
|
|
||
| latents = latents.to(dtype=dtype) | ||
| if self.vae.config.scaling_factor is not None: | ||
| latents = latents / self.vae.config.scaling_factor | ||
| if self.vae.config.shift_factor is not None: | ||
| latents = latents + self.vae.config.shift_factor | ||
| image = self.vae.decode(latents, return_dict=False)[0] | ||
|
|
||
| return image | ||
|
|
||
| def _processing_parallel( | ||
| self, | ||
| latents, | ||
| freqs_cis, | ||
| prompt_embeds, | ||
| prompt_attention_mask, | ||
| ref_latents, | ||
| negative_prompt_embeds, | ||
| negative_prompt_attention_mask, | ||
| timesteps, | ||
| dtype, | ||
| step_func=None, | ||
| ) -> torch.Tensor: | ||
| """CFG parallel denoising loop: each rank computes one CFG branch, returns latents. | ||
|
|
||
| Rank 0: cond branch (prompt_embeds, ref_latents) | ||
| Rank 1: uncond branch (negative_prompt_embeds, None) | ||
| Rank 2: ref branch (negative_prompt_embeds, ref_latents) | ||
| """ | ||
| cfg_group = get_cfg_group() | ||
| cfg_rank = get_classifier_free_guidance_rank() | ||
| use_cfg_img = self.image_guidance_scale > 1.0 | ||
|
zzhuoxin1508 marked this conversation as resolved.
|
||
|
|
||
| latents = latents.contiguous() | ||
| cfg_group.broadcast(latents, src=0) | ||
|
|
||
| if cfg_rank == 0: | ||
| branch_prompt_embeds = prompt_embeds | ||
| branch_attention_mask = prompt_attention_mask | ||
| branch_ref_latents = ref_latents | ||
| elif cfg_rank == 1: | ||
| branch_prompt_embeds = negative_prompt_embeds | ||
| branch_attention_mask = negative_prompt_attention_mask | ||
| branch_ref_latents = None | ||
| else: | ||
|
zzhuoxin1508 marked this conversation as resolved.
|
||
| branch_prompt_embeds = negative_prompt_embeds | ||
| branch_attention_mask = negative_prompt_attention_mask | ||
| branch_ref_latents = ref_latents | ||
|
|
||
| for i, t in enumerate(timesteps): | ||
| in_cfg_range = self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] | ||
| use_cfg_img_this_step = in_cfg_range and use_cfg_img | ||
|
|
||
| model_pred_uncond = self.predict( | ||
| if in_cfg_range: | ||
| local_pred = self.predict( | ||
| t=t, | ||
| latents=latents, | ||
| prompt_embeds=negative_prompt_embeds, | ||
| prompt_embeds=branch_prompt_embeds, | ||
| freqs_cis=freqs_cis, | ||
| prompt_attention_mask=negative_prompt_attention_mask, | ||
| ref_image_hidden_states=None, | ||
| prompt_attention_mask=branch_attention_mask, | ||
| ref_image_hidden_states=branch_ref_latents, | ||
| ) | ||
|
|
||
| model_pred = ( | ||
| model_pred_uncond | ||
| + image_guidance_scale * (model_pred_ref - model_pred_uncond) | ||
| + text_guidance_scale * (model_pred - model_pred_ref) | ||
| ) | ||
| elif text_guidance_scale > 1.0: | ||
| model_pred_uncond = self.predict( | ||
| local_pred = local_pred.contiguous() | ||
| gathered = cfg_group.all_gather(local_pred, separate_tensors=True) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
| model_pred, model_pred_uncond = gathered[0], gathered[1] | ||
| if use_cfg_img_this_step: | ||
| model_pred_ref = gathered[2] | ||
| model_pred = ( | ||
| model_pred_uncond | ||
| + self.image_guidance_scale * (model_pred_ref - model_pred_uncond) | ||
| + self.text_guidance_scale * (model_pred - model_pred_ref) | ||
| ) | ||
| else: | ||
| model_pred = model_pred_uncond + self.text_guidance_scale * (model_pred - model_pred_uncond) | ||
| latents = self.scheduler.step(model_pred, t, latents, return_dict=False)[0] | ||
| else: | ||
| # Outside CFG interval: all ranks use cond branch, no comm | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Outside the CFG range, every rank computes the same cond prediction independently — wasted FLOPs on ranks 1+. Consider having only rank 0 run the prediction and broadcast the result to the other ranks.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Thanks, but the data is the same on all cards; if we only ran it on rank 0, we’d just be adding an extra broadcast step, so it wouldn't really save any time. |
||
| model_pred = self.predict( | ||
| t=t, | ||
| latents=latents, | ||
| prompt_embeds=negative_prompt_embeds, | ||
| prompt_embeds=prompt_embeds, | ||
| freqs_cis=freqs_cis, | ||
| prompt_attention_mask=negative_prompt_attention_mask, | ||
| ref_image_hidden_states=None, | ||
| prompt_attention_mask=prompt_attention_mask, | ||
| ref_image_hidden_states=ref_latents, | ||
| ) | ||
| model_pred = model_pred_uncond + text_guidance_scale * (model_pred - model_pred_uncond) | ||
|
|
||
| latents = self.scheduler.step(model_pred, t, latents, return_dict=False)[0] | ||
| latents = self.scheduler.step(model_pred, t, latents, return_dict=False)[0] | ||
|
|
||
| latents = latents.to(dtype=dtype) | ||
|
|
||
| if step_func is not None: | ||
| step_func(i, self._num_timesteps) | ||
|
|
||
| latents = latents.to(dtype=dtype) | ||
| if self.vae.config.scaling_factor is not None: | ||
| latents = latents / self.vae.config.scaling_factor | ||
| if self.vae.config.shift_factor is not None: | ||
| latents = latents + self.vae.config.shift_factor | ||
| image = self.vae.decode(latents, return_dict=False)[0] | ||
|
|
||
| return image | ||
| return latents | ||
|
|
||
| def predict( | ||
| self, | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.