Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
225 changes: 116 additions & 109 deletions vllm_omni/diffusion/models/helios/pipeline_helios.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
from vllm_omni.diffusion.models.helios.helios_transformer import HeliosTransformer3DModel
from vllm_omni.diffusion.models.helios.scheduling_helios import HeliosScheduler
from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin
from vllm_omni.diffusion.request import OmniDiffusionRequest
from vllm_omni.platforms import current_omni_platform

Expand Down Expand Up @@ -140,7 +141,7 @@ def pre_process_func(request: OmniDiffusionRequest) -> OmniDiffusionRequest:
return pre_process_func


class HeliosPipeline(nn.Module, CFGParallelMixin):
class HeliosPipeline(nn.Module, CFGParallelMixin, ProgressBarMixin):
"""Helios text-to-video / image-to-video / video-to-video pipeline for vllm-omni.

Supports T2V, I2V (with image input), and V2V (with video input).
Expand Down Expand Up @@ -662,67 +663,70 @@ def _stage1_sample(
batch_size = latents.shape[0]
do_true_cfg = guidance_scale > 1.0 and negative_prompt_embeds is not None

for i, t in enumerate(timesteps):
self._current_timestep = t
timestep = t.expand(batch_size)

transformer_kwargs = {
"hidden_states": latents.to(transformer_dtype),
"timestep": timestep,
"indices_hidden_states": indices_hidden_states,
"indices_latents_history_short": indices_latents_history_short,
"indices_latents_history_mid": indices_latents_history_mid,
"indices_latents_history_long": indices_latents_history_long,
"latents_history_short": latents_history_short.to(transformer_dtype),
"latents_history_mid": latents_history_mid.to(transformer_dtype),
"latents_history_long": latents_history_long.to(transformer_dtype),
"attention_kwargs": attention_kwargs,
"return_dict": False,
}

if use_cfg_zero_star and do_true_cfg:
noise_pred = self.transformer(
encoder_hidden_states=prompt_embeds,
**transformer_kwargs,
)[0]

noise_uncond = self.transformer(
encoder_hidden_states=negative_prompt_embeds,
**transformer_kwargs,
)[0]

positive_flat = noise_pred.view(batch_size, -1)
negative_flat = noise_uncond.view(batch_size, -1)
alpha_cfg = optimized_scale(positive_flat, negative_flat)
alpha_cfg = alpha_cfg.view(batch_size, *([1] * (len(noise_pred.shape) - 1)))
alpha_cfg = alpha_cfg.to(noise_pred.dtype)

if (i <= zero_steps) and use_zero_init:
noise_pred = noise_pred * 0.0
else:
noise_pred = noise_uncond * alpha_cfg + guidance_scale * (noise_pred - noise_uncond * alpha_cfg)
else:
positive_kwargs = {
"encoder_hidden_states": prompt_embeds,
**transformer_kwargs,
with self.progress_bar(total=len(timesteps)) as pbar:
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The only change in this block is wrapping the loop in `with self.progress_bar(total=len(timesteps)) as pbar:`.

for i, t in enumerate(timesteps):
self._current_timestep = t
timestep = t.expand(batch_size)

transformer_kwargs = {
"hidden_states": latents.to(transformer_dtype),
"timestep": timestep,
"indices_hidden_states": indices_hidden_states,
"indices_latents_history_short": indices_latents_history_short,
"indices_latents_history_mid": indices_latents_history_mid,
"indices_latents_history_long": indices_latents_history_long,
"latents_history_short": latents_history_short.to(transformer_dtype),
"latents_history_mid": latents_history_mid.to(transformer_dtype),
"latents_history_long": latents_history_long.to(transformer_dtype),
"attention_kwargs": attention_kwargs,
"return_dict": False,
}
if do_true_cfg:
negative_kwargs = {
"encoder_hidden_states": negative_prompt_embeds,

if use_cfg_zero_star and do_true_cfg:
noise_pred = self.transformer(
encoder_hidden_states=prompt_embeds,
**transformer_kwargs,
}
)[0]

noise_uncond = self.transformer(
encoder_hidden_states=negative_prompt_embeds,
**transformer_kwargs,
)[0]

positive_flat = noise_pred.view(batch_size, -1)
negative_flat = noise_uncond.view(batch_size, -1)
alpha_cfg = optimized_scale(positive_flat, negative_flat)
alpha_cfg = alpha_cfg.view(batch_size, *([1] * (len(noise_pred.shape) - 1)))
alpha_cfg = alpha_cfg.to(noise_pred.dtype)

if (i <= zero_steps) and use_zero_init:
noise_pred = noise_pred * 0.0
else:
noise_pred = noise_uncond * alpha_cfg + guidance_scale * (noise_pred - noise_uncond * alpha_cfg)
else:
negative_kwargs = None

noise_pred = self.predict_noise_maybe_with_cfg(
do_true_cfg=do_true_cfg,
true_cfg_scale=guidance_scale,
positive_kwargs=positive_kwargs,
negative_kwargs=negative_kwargs,
cfg_normalize=False,
)
positive_kwargs = {
"encoder_hidden_states": prompt_embeds,
**transformer_kwargs,
}
if do_true_cfg:
negative_kwargs = {
"encoder_hidden_states": negative_prompt_embeds,
**transformer_kwargs,
}
else:
negative_kwargs = None

noise_pred = self.predict_noise_maybe_with_cfg(
do_true_cfg=do_true_cfg,
true_cfg_scale=guidance_scale,
positive_kwargs=positive_kwargs,
negative_kwargs=negative_kwargs,
cfg_normalize=False,
)

latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg)
latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg)

pbar.update()

return latents

Expand Down Expand Up @@ -807,62 +811,65 @@ def _stage2_sample(
if self.is_distilled and start_point_list is not None:
start_point_list.append(latents)

for idx, t in enumerate(timesteps):
self._current_timestep = t
timestep = t.expand(latents.shape[0]).to(torch.int64)

transformer_kwargs = {
"hidden_states": latents.to(transformer_dtype),
"timestep": timestep,
"indices_hidden_states": indices_hidden_states,
"indices_latents_history_short": indices_latents_history_short,
"indices_latents_history_mid": indices_latents_history_mid,
"indices_latents_history_long": indices_latents_history_long,
"latents_history_short": latents_history_short.to(transformer_dtype),
"latents_history_mid": latents_history_mid.to(transformer_dtype),
"latents_history_long": latents_history_long.to(transformer_dtype),
"attention_kwargs": attention_kwargs,
"return_dict": False,
}

noise_pred = self.transformer(
encoder_hidden_states=prompt_embeds,
**transformer_kwargs,
)[0]
with self.progress_bar(total=len(timesteps)) as pbar:
for idx, t in enumerate(timesteps):
self._current_timestep = t
timestep = t.expand(latents.shape[0]).to(torch.int64)

transformer_kwargs = {
"hidden_states": latents.to(transformer_dtype),
"timestep": timestep,
"indices_hidden_states": indices_hidden_states,
"indices_latents_history_short": indices_latents_history_short,
"indices_latents_history_mid": indices_latents_history_mid,
"indices_latents_history_long": indices_latents_history_long,
"latents_history_short": latents_history_short.to(transformer_dtype),
"latents_history_mid": latents_history_mid.to(transformer_dtype),
"latents_history_long": latents_history_long.to(transformer_dtype),
"attention_kwargs": attention_kwargs,
"return_dict": False,
}

if do_true_cfg:
noise_uncond = self.transformer(
encoder_hidden_states=negative_prompt_embeds,
noise_pred = self.transformer(
encoder_hidden_states=prompt_embeds,
**transformer_kwargs,
)[0]

if use_cfg_zero_star:
positive_flat = noise_pred.view(batch_size, -1)
negative_flat = noise_uncond.view(batch_size, -1)
alpha_cfg = optimized_scale(positive_flat, negative_flat)
alpha_cfg = alpha_cfg.view(batch_size, *([1] * (len(noise_pred.shape) - 1)))
alpha_cfg = alpha_cfg.to(noise_pred.dtype)

if (i_s == 0 and idx <= zero_steps) and use_zero_init:
noise_pred = noise_pred * 0.0
if do_true_cfg:
noise_uncond = self.transformer(
encoder_hidden_states=negative_prompt_embeds,
**transformer_kwargs,
)[0]

if use_cfg_zero_star:
positive_flat = noise_pred.view(batch_size, -1)
negative_flat = noise_uncond.view(batch_size, -1)
alpha_cfg = optimized_scale(positive_flat, negative_flat)
alpha_cfg = alpha_cfg.view(batch_size, *([1] * (len(noise_pred.shape) - 1)))
alpha_cfg = alpha_cfg.to(noise_pred.dtype)

if (i_s == 0 and idx <= zero_steps) and use_zero_init:
noise_pred = noise_pred * 0.0
else:
noise_pred = noise_uncond * alpha_cfg + guidance_scale * (
noise_pred - noise_uncond * alpha_cfg
)
else:
noise_pred = noise_uncond * alpha_cfg + guidance_scale * (
noise_pred - noise_uncond * alpha_cfg
)
else:
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)

latents = self.scheduler.step(
noise_pred,
t,
latents,
return_dict=False,
cur_sampling_step=idx,
dmd_noisy_tensor=start_point_list[i_s] if start_point_list is not None else None,
dmd_sigmas=self.scheduler.sigmas,
dmd_timesteps=self.scheduler.timesteps,
all_timesteps=timesteps,
)[0]
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)

latents = self.scheduler.step(
noise_pred,
t,
latents,
return_dict=False,
cur_sampling_step=idx,
dmd_noisy_tensor=start_point_list[i_s] if start_point_list is not None else None,
dmd_sigmas=self.scheduler.sigmas,
dmd_timesteps=self.scheduler.timesteps,
all_timesteps=timesteps,
)[0]

pbar.update()

return latents

Expand Down
53 changes: 53 additions & 0 deletions vllm_omni/diffusion/models/progress_bar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Progress bar mixin for diffusion pipelines.

Provides a diffusers-compatible progress_bar() method that wraps tqdm,
automatically disabling output on non-zero ranks in distributed settings.
"""

import torch
from tqdm.auto import tqdm


class ProgressBarMixin:
    """Adds a tqdm-backed ``progress_bar()`` helper to a pipeline class.

    Extra tqdm keyword arguments can be registered once through
    ``set_progress_bar_config`` and are forwarded on every call. Unless the
    registered config carries its own ``"disable"`` entry, the bar is shown
    only when ``_is_rank_zero()`` reports True.

    Usage in pipeline:
        class MyPipeline(nn.Module, CFGParallelMixin, ProgressBarMixin):
            def diffuse(self, ...):
                with self.progress_bar(total=num_steps) as pbar:
                    for i, t in enumerate(timesteps):
                        ...
                        pbar.update()
    """

    def set_progress_bar_config(self, **kwargs):
        """Replace the stored tqdm keyword arguments with ``kwargs``."""
        self._progress_bar_config = kwargs

    def progress_bar(self, iterable=None, total=None):
        """Build a tqdm bar around ``iterable`` or for ``total`` manual steps.

        Args:
            iterable: Optional iterable to wrap; takes precedence over
                ``total`` when both are supplied.
            total: Number of expected updates, used when no iterable is given.

        Returns:
            The tqdm instance created with the stored configuration.

        Raises:
            ValueError: If the stored configuration is not a ``dict``, or if
                neither ``iterable`` nor ``total`` is provided.
        """
        # Lazily create the config so classes that never call
        # set_progress_bar_config() still work.
        if not hasattr(self, "_progress_bar_config"):
            self._progress_bar_config = {}

        if not isinstance(self._progress_bar_config, dict):
            raise ValueError(
                f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
            )

        # Work on a copy so the rank-dependent default never leaks back into
        # the user-supplied configuration.
        tqdm_kwargs = {**self._progress_bar_config}
        if "disable" not in tqdm_kwargs:
            # An explicit user choice wins; otherwise only rank 0 prints.
            tqdm_kwargs["disable"] = not _is_rank_zero()

        if iterable is not None:
            return tqdm(iterable, **tqdm_kwargs)
        if total is None:
            raise ValueError("Either `total` or `iterable` has to be defined.")
        return tqdm(total=total, **tqdm_kwargs)


def _is_rank_zero() -> bool:
if not torch.distributed.is_initialized():
return True
return torch.distributed.get_rank() == 0
Loading