diff --git a/examples/offline_inference/image_to_image/image_edit.py b/examples/offline_inference/image_to_image/image_edit.py index a8035a3fdcb..db74dfc6b3d 100644 --- a/examples/offline_inference/image_to_image/image_edit.py +++ b/examples/offline_inference/image_to_image/image_edit.py @@ -297,7 +297,7 @@ def parse_args() -> argparse.Namespace: "--cfg-parallel-size", type=int, default=1, - choices=[1, 2], + choices=[1, 2, 3], help="Number of GPUs used for classifier free guidance parallel size.", ) parser.add_argument( diff --git a/examples/offline_inference/text_to_image/text_to_image.py b/examples/offline_inference/text_to_image/text_to_image.py index 927b0f0b087..1d6a6591d83 100644 --- a/examples/offline_inference/text_to_image/text_to_image.py +++ b/examples/offline_inference/text_to_image/text_to_image.py @@ -126,7 +126,7 @@ def parse_args() -> argparse.Namespace: "--cfg-parallel-size", type=int, default=1, - choices=[1, 2], + choices=[1, 2, 3], help="Number of GPUs used for classifier free guidance parallel size.", ) parser.add_argument( diff --git a/vllm_omni/diffusion/models/omnigen2/pipeline_omnigen2.py b/vllm_omni/diffusion/models/omnigen2/pipeline_omnigen2.py index 2d370aea19c..fe79ee1ad15 100644 --- a/vllm_omni/diffusion/models/omnigen2/pipeline_omnigen2.py +++ b/vllm_omni/diffusion/models/omnigen2/pipeline_omnigen2.py @@ -29,6 +29,11 @@ from vllm.model_executor.models.utils import AutoWeightsLoader from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.parallel_state import ( + get_cfg_group, + get_classifier_free_guidance_rank, + get_classifier_free_guidance_world_size, +) from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.models.omnigen2.omnigen2_transformer import ( @@ -1170,72 +1175,178 @@ def processing( ) self._num_timesteps = len(timesteps) - for i, t in enumerate(timesteps): - model_pred = self.predict( - t=t, + cfg_world_size = get_classifier_free_guidance_world_size() + use_cfg_img = self.image_guidance_scale > 1.0 + cfg_parallel_ready = ( + self.text_guidance_scale > 1.0 + and cfg_world_size > 1 + # image guidance needs a 3rd rank for the ref branch; fall back to serial if not available + and (not use_cfg_img or cfg_world_size >= 3) + ) + if cfg_parallel_ready: + latents = self._processing_parallel( latents=latents, - prompt_embeds=prompt_embeds, freqs_cis=freqs_cis, + prompt_embeds=prompt_embeds, prompt_attention_mask=prompt_attention_mask, - ref_image_hidden_states=ref_latents, - ) - text_guidance_scale = ( - self.text_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0 + ref_latents=ref_latents, + negative_prompt_embeds=negative_prompt_embeds, + negative_prompt_attention_mask=negative_prompt_attention_mask, + timesteps=timesteps, + dtype=dtype, + step_func=step_func, ) - image_guidance_scale = ( - self.image_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0 - ) - - if text_guidance_scale > 1.0 and image_guidance_scale > 1.0: - model_pred_ref = self.predict( + else: + for i, t in enumerate(timesteps): + model_pred = self.predict( t=t, latents=latents, - prompt_embeds=negative_prompt_embeds, + prompt_embeds=prompt_embeds, freqs_cis=freqs_cis, - prompt_attention_mask=negative_prompt_attention_mask, + prompt_attention_mask=prompt_attention_mask, ref_image_hidden_states=ref_latents, ) + text_guidance_scale = ( + self.text_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0 + ) + image_guidance_scale = ( + self.image_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0 + ) + + if text_guidance_scale > 1.0 and image_guidance_scale > 1.0: + model_pred_ref = self.predict( + t=t, + latents=latents, + prompt_embeds=negative_prompt_embeds, + freqs_cis=freqs_cis, + prompt_attention_mask=negative_prompt_attention_mask, + ref_image_hidden_states=ref_latents, + ) + + model_pred_uncond = self.predict( + t=t, + latents=latents, + prompt_embeds=negative_prompt_embeds, + freqs_cis=freqs_cis, + prompt_attention_mask=negative_prompt_attention_mask, + ref_image_hidden_states=None, + ) + + model_pred = ( + model_pred_uncond + + image_guidance_scale * (model_pred_ref - model_pred_uncond) + + text_guidance_scale * (model_pred - model_pred_ref) + ) + elif text_guidance_scale > 1.0: + model_pred_uncond = self.predict( + t=t, + latents=latents, + prompt_embeds=negative_prompt_embeds, + freqs_cis=freqs_cis, + prompt_attention_mask=negative_prompt_attention_mask, + ref_image_hidden_states=None, + ) + model_pred = model_pred_uncond + text_guidance_scale * (model_pred - model_pred_uncond) + + latents = self.scheduler.step(model_pred, t, latents, return_dict=False)[0] + + latents = latents.to(dtype=dtype) + + if step_func is not None: + step_func(i, self._num_timesteps) + + latents = latents.to(dtype=dtype) + if self.vae.config.scaling_factor is not None: + latents = latents / self.vae.config.scaling_factor + if self.vae.config.shift_factor is not None: + latents = latents + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + + return image + + def _processing_parallel( + self, + latents, + freqs_cis, + prompt_embeds, + prompt_attention_mask, + ref_latents, + negative_prompt_embeds, + negative_prompt_attention_mask, + timesteps, + dtype, + step_func=None, + ) -> torch.Tensor: + """CFG parallel denoising loop: each rank computes one CFG branch, returns latents. + + Rank 0: cond branch (prompt_embeds, ref_latents) + Rank 1: uncond branch (negative_prompt_embeds, None) + Rank 2: ref branch (negative_prompt_embeds, ref_latents) + """ + cfg_group = get_cfg_group() + cfg_rank = get_classifier_free_guidance_rank() + use_cfg_img = self.image_guidance_scale > 1.0 + + latents = latents.contiguous() + cfg_group.broadcast(latents, src=0) + + if cfg_rank == 0: + branch_prompt_embeds = prompt_embeds + branch_attention_mask = prompt_attention_mask + branch_ref_latents = ref_latents + elif cfg_rank == 1: + branch_prompt_embeds = negative_prompt_embeds + branch_attention_mask = negative_prompt_attention_mask + branch_ref_latents = None + else: + branch_prompt_embeds = negative_prompt_embeds + branch_attention_mask = negative_prompt_attention_mask + branch_ref_latents = ref_latents + + for i, t in enumerate(timesteps): + in_cfg_range = self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] + use_cfg_img_this_step = in_cfg_range and use_cfg_img - model_pred_uncond = self.predict( + if in_cfg_range: + local_pred = self.predict( t=t, latents=latents, - prompt_embeds=negative_prompt_embeds, + prompt_embeds=branch_prompt_embeds, freqs_cis=freqs_cis, - prompt_attention_mask=negative_prompt_attention_mask, - ref_image_hidden_states=None, + prompt_attention_mask=branch_attention_mask, + ref_image_hidden_states=branch_ref_latents, ) - - model_pred = ( - model_pred_uncond - + image_guidance_scale * (model_pred_ref - model_pred_uncond) - + text_guidance_scale * (model_pred - model_pred_ref) - ) - elif text_guidance_scale > 1.0: - model_pred_uncond = self.predict( + local_pred = local_pred.contiguous() + gathered = cfg_group.all_gather(local_pred, separate_tensors=True) + model_pred, model_pred_uncond = gathered[0], gathered[1] + if use_cfg_img_this_step: + model_pred_ref = gathered[2] + model_pred = ( + model_pred_uncond + + self.image_guidance_scale * (model_pred_ref - model_pred_uncond) + + self.text_guidance_scale * (model_pred - model_pred_ref) + ) + else: + model_pred = model_pred_uncond + self.text_guidance_scale * (model_pred - model_pred_uncond) + latents = self.scheduler.step(model_pred, t, latents, return_dict=False)[0] + else: + # Outside CFG interval: all ranks use cond branch, no comm + model_pred = self.predict( t=t, latents=latents, - prompt_embeds=negative_prompt_embeds, + prompt_embeds=prompt_embeds, freqs_cis=freqs_cis, - prompt_attention_mask=negative_prompt_attention_mask, - ref_image_hidden_states=None, + prompt_attention_mask=prompt_attention_mask, + ref_image_hidden_states=ref_latents, ) - model_pred = model_pred_uncond + text_guidance_scale * (model_pred - model_pred_uncond) - - latents = self.scheduler.step(model_pred, t, latents, return_dict=False)[0] + latents = self.scheduler.step(model_pred, t, latents, return_dict=False)[0] latents = latents.to(dtype=dtype) if step_func is not None: step_func(i, self._num_timesteps) - latents = latents.to(dtype=dtype) - if self.vae.config.scaling_factor is not None: - latents = latents / self.vae.config.scaling_factor - if self.vae.config.shift_factor is not None: - latents = latents + self.vae.config.shift_factor - image = self.vae.decode(latents, return_dict=False)[0] - - return image + return latents def predict( self,