Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/offline_inference/image_to_image/image_edit.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ def parse_args() -> argparse.Namespace:
"--cfg-parallel-size",
type=int,
default=1,
choices=[1, 2],
choices=[1, 2, 3],
help="Number of GPUs used for classifier free guidance parallel size.",
)
parser.add_argument(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def parse_args() -> argparse.Namespace:
"--cfg-parallel-size",
type=int,
default=1,
choices=[1, 2],
choices=[1, 2, 3],
help="Number of GPUs used for classifier free guidance parallel size.",
)
parser.add_argument(
Expand Down
195 changes: 153 additions & 42 deletions vllm_omni/diffusion/models/omnigen2/pipeline_omnigen2.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@
from vllm.model_executor.models.utils import AutoWeightsLoader

from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig
from vllm_omni.diffusion.distributed.parallel_state import (
get_cfg_group,
get_classifier_free_guidance_rank,
get_classifier_free_guidance_world_size,
)
from vllm_omni.diffusion.distributed.utils import get_local_device
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
from vllm_omni.diffusion.models.omnigen2.omnigen2_transformer import (
Expand Down Expand Up @@ -1170,72 +1175,178 @@ def processing(
)
self._num_timesteps = len(timesteps)

for i, t in enumerate(timesteps):
model_pred = self.predict(
t=t,
cfg_world_size = get_classifier_free_guidance_world_size()
use_cfg_img = self.image_guidance_scale > 1.0
Comment thread
zzhuoxin1508 marked this conversation as resolved.
cfg_parallel_ready = (
self.text_guidance_scale > 1.0
and cfg_world_size > 1
# image guidance needs a 3rd rank for the ref branch; fall back to serial if not available
and (not use_cfg_img or cfg_world_size >= 3)
)
if cfg_parallel_ready:
latents = self._processing_parallel(
latents=latents,
prompt_embeds=prompt_embeds,
freqs_cis=freqs_cis,
prompt_embeds=prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
ref_image_hidden_states=ref_latents,
)
text_guidance_scale = (
self.text_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0
ref_latents=ref_latents,
negative_prompt_embeds=negative_prompt_embeds,
negative_prompt_attention_mask=negative_prompt_attention_mask,
timesteps=timesteps,
dtype=dtype,
step_func=step_func,
)
image_guidance_scale = (
self.image_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0
)

if text_guidance_scale > 1.0 and image_guidance_scale > 1.0:
model_pred_ref = self.predict(
else:
for i, t in enumerate(timesteps):
model_pred = self.predict(
t=t,
latents=latents,
prompt_embeds=negative_prompt_embeds,
prompt_embeds=prompt_embeds,
freqs_cis=freqs_cis,
prompt_attention_mask=negative_prompt_attention_mask,
prompt_attention_mask=prompt_attention_mask,
ref_image_hidden_states=ref_latents,
)
text_guidance_scale = (
self.text_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0
)
image_guidance_scale = (
self.image_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0
)

if text_guidance_scale > 1.0 and image_guidance_scale > 1.0:
model_pred_ref = self.predict(
t=t,
latents=latents,
prompt_embeds=negative_prompt_embeds,
freqs_cis=freqs_cis,
prompt_attention_mask=negative_prompt_attention_mask,
ref_image_hidden_states=ref_latents,
)

model_pred_uncond = self.predict(
t=t,
latents=latents,
prompt_embeds=negative_prompt_embeds,
freqs_cis=freqs_cis,
prompt_attention_mask=negative_prompt_attention_mask,
ref_image_hidden_states=None,
)

model_pred = (
model_pred_uncond
+ image_guidance_scale * (model_pred_ref - model_pred_uncond)
+ text_guidance_scale * (model_pred - model_pred_ref)
)
elif text_guidance_scale > 1.0:
model_pred_uncond = self.predict(
t=t,
latents=latents,
prompt_embeds=negative_prompt_embeds,
freqs_cis=freqs_cis,
prompt_attention_mask=negative_prompt_attention_mask,
ref_image_hidden_states=None,
)
model_pred = model_pred_uncond + text_guidance_scale * (model_pred - model_pred_uncond)

latents = self.scheduler.step(model_pred, t, latents, return_dict=False)[0]

latents = latents.to(dtype=dtype)

if step_func is not None:
step_func(i, self._num_timesteps)

latents = latents.to(dtype=dtype)
if self.vae.config.scaling_factor is not None:
latents = latents / self.vae.config.scaling_factor
if self.vae.config.shift_factor is not None:
latents = latents + self.vae.config.shift_factor
image = self.vae.decode(latents, return_dict=False)[0]

return image

def _processing_parallel(
self,
latents,
freqs_cis,
prompt_embeds,
prompt_attention_mask,
ref_latents,
negative_prompt_embeds,
negative_prompt_attention_mask,
timesteps,
dtype,
step_func=None,
) -> torch.Tensor:
"""CFG parallel denoising loop: each rank computes one CFG branch, returns latents.

Rank 0: cond branch (prompt_embeds, ref_latents)
Rank 1: uncond branch (negative_prompt_embeds, None)
Rank 2: ref branch (negative_prompt_embeds, ref_latents)
"""
cfg_group = get_cfg_group()
cfg_rank = get_classifier_free_guidance_rank()
use_cfg_img = self.image_guidance_scale > 1.0
Comment thread
zzhuoxin1508 marked this conversation as resolved.

latents = latents.contiguous()
cfg_group.broadcast(latents, src=0)

if cfg_rank == 0:
branch_prompt_embeds = prompt_embeds
branch_attention_mask = prompt_attention_mask
branch_ref_latents = ref_latents
elif cfg_rank == 1:
branch_prompt_embeds = negative_prompt_embeds
branch_attention_mask = negative_prompt_attention_mask
branch_ref_latents = None
else:
Comment thread
zzhuoxin1508 marked this conversation as resolved.
branch_prompt_embeds = negative_prompt_embeds
branch_attention_mask = negative_prompt_attention_mask
branch_ref_latents = ref_latents

for i, t in enumerate(timesteps):
in_cfg_range = self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1]
use_cfg_img_this_step = in_cfg_range and use_cfg_img

model_pred_uncond = self.predict(
if in_cfg_range:
local_pred = self.predict(
t=t,
latents=latents,
prompt_embeds=negative_prompt_embeds,
prompt_embeds=branch_prompt_embeds,
freqs_cis=freqs_cis,
prompt_attention_mask=negative_prompt_attention_mask,
ref_image_hidden_states=None,
prompt_attention_mask=branch_attention_mask,
ref_image_hidden_states=branch_ref_latents,
)

model_pred = (
model_pred_uncond
+ image_guidance_scale * (model_pred_ref - model_pred_uncond)
+ text_guidance_scale * (model_pred - model_pred_ref)
)
elif text_guidance_scale > 1.0:
model_pred_uncond = self.predict(
local_pred = local_pred.contiguous()
gathered = cfg_group.all_gather(local_pred, separate_tensors=True)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

all_gather_into_tensor requires contiguous input. local_pred coming out of self.predict() may not be contiguous depending on the transformer output layout. Add local_pred = local_pred.contiguous() before the all-gather, same way you already do for latents at line 1291.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

model_pred, model_pred_uncond = gathered[0], gathered[1]
if use_cfg_img_this_step:
model_pred_ref = gathered[2]
model_pred = (
model_pred_uncond
+ self.image_guidance_scale * (model_pred_ref - model_pred_uncond)
+ self.text_guidance_scale * (model_pred - model_pred_ref)
)
else:
model_pred = model_pred_uncond + self.text_guidance_scale * (model_pred - model_pred_uncond)
latents = self.scheduler.step(model_pred, t, latents, return_dict=False)[0]
else:
# Outside CFG interval: all ranks use cond branch, no comm
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Outside the CFG range every rank computes the same cond prediction independently — wasted FLOPs on ranks 1+. Consider having only rank 0 run predict and broadcasting model_pred, like you already do for the initial latents sync.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks,but data is the same on all cards, if we only run it on Rank 0, we’d just be adding an extra broadcast step. it wouldn't really save any time.

model_pred = self.predict(
t=t,
latents=latents,
prompt_embeds=negative_prompt_embeds,
prompt_embeds=prompt_embeds,
freqs_cis=freqs_cis,
prompt_attention_mask=negative_prompt_attention_mask,
ref_image_hidden_states=None,
prompt_attention_mask=prompt_attention_mask,
ref_image_hidden_states=ref_latents,
)
model_pred = model_pred_uncond + text_guidance_scale * (model_pred - model_pred_uncond)

latents = self.scheduler.step(model_pred, t, latents, return_dict=False)[0]
latents = self.scheduler.step(model_pred, t, latents, return_dict=False)[0]

latents = latents.to(dtype=dtype)

if step_func is not None:
step_func(i, self._num_timesteps)

latents = latents.to(dtype=dtype)
if self.vae.config.scaling_factor is not None:
latents = latents / self.vae.config.scaling_factor
if self.vae.config.shift_factor is not None:
latents = latents + self.vae.config.shift_factor
image = self.vae.decode(latents, return_dict=False)[0]

return image
return latents

def predict(
self,
Expand Down
Loading