From b27d792e765b73ff015fb4bf88ec59b986f0b0f2 Mon Sep 17 00:00:00 2001 From: Lancer Date: Sun, 5 Apr 2026 10:02:58 +0800 Subject: [PATCH 1/2] [Feat] add diffusion pipeline profiler support to FluxKontextPipeline et al. Signed-off-by: Lancer --- vllm_omni/diffusion/models/flux/pipeline_flux_kontext.py | 7 ++++++- vllm_omni/diffusion/models/flux2/pipeline_flux2.py | 7 ++++++- .../models/hunyuan_video/pipeline_hunyuan_video_1_5.py | 7 ++++++- .../models/hunyuan_video/pipeline_hunyuan_video_1_5_i2v.py | 7 ++++++- 4 files changed, 24 insertions(+), 4 deletions(-) diff --git a/vllm_omni/diffusion/models/flux/pipeline_flux_kontext.py b/vllm_omni/diffusion/models/flux/pipeline_flux_kontext.py index 3232b436d60..b232c0a7369 100644 --- a/vllm_omni/diffusion/models/flux/pipeline_flux_kontext.py +++ b/vllm_omni/diffusion/models/flux/pipeline_flux_kontext.py @@ -31,6 +31,7 @@ ) from vllm_omni.diffusion.models.flux.flux_pipeline_mixin import FluxPipelineMixin from vllm_omni.diffusion.models.interface import SupportImageInput +from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs from vllm_omni.logger import init_logger @@ -67,7 +68,7 @@ def post_process_func(images: torch.Tensor) -> list[PIL.Image.Image]: return post_process_func -class FluxKontextPipeline(nn.Module, FluxPipelineMixin, SupportImageInput): +class FluxKontextPipeline(nn.Module, FluxPipelineMixin, SupportImageInput, DiffusionPipelineProfilerMixin): """FLUX.1-Kontext pipeline for image editing with text guidance.""" support_image_input = True @@ -148,6 +149,10 @@ def __init__( self._callback_tensor_inputs = ["latents", "prompt_embeds"] self.latent_channels = self.vae.config.latent_channels if hasattr(self.vae, "config") else 16 + self.setup_diffusion_pipeline_profiler( + 
enable_diffusion_pipeline_profiler=self.od_config.enable_diffusion_pipeline_profiler + ) + def _get_t5_prompt_embeds( self, prompt: str | list[str] = None, diff --git a/vllm_omni/diffusion/models/flux2/pipeline_flux2.py b/vllm_omni/diffusion/models/flux2/pipeline_flux2.py index c5bf9b77d9e..4f8011f2283 100644 --- a/vllm_omni/diffusion/models/flux2/pipeline_flux2.py +++ b/vllm_omni/diffusion/models/flux2/pipeline_flux2.py @@ -29,6 +29,7 @@ from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.models.flux2 import Flux2Transformer2DModel from vllm_omni.diffusion.models.interface import SupportImageInput +from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs from vllm_omni.model_executor.model_loader.weight_utils import download_weights_from_hf_specific @@ -331,7 +332,7 @@ def retrieve_latents(encoder_output: torch.Tensor, generator: torch.Generator = raise AttributeError("Could not access latents of provided encoder_output") -class Flux2Pipeline(nn.Module, SupportImageInput): +class Flux2Pipeline(nn.Module, SupportImageInput, DiffusionPipelineProfilerMixin): """Flux2 pipeline for text-to-image generation.""" _callback_tensor_inputs = ["latents", "prompt_embeds"] @@ -389,6 +390,10 @@ def __init__( self._guidance_scale = None self._attention_kwargs = None self._num_timesteps = None + + self.setup_diffusion_pipeline_profiler( + enable_diffusion_pipeline_profiler=self.od_config.enable_diffusion_pipeline_profiler + ) self._current_timestep = None self._interrupt = False diff --git a/vllm_omni/diffusion/models/hunyuan_video/pipeline_hunyuan_video_1_5.py b/vllm_omni/diffusion/models/hunyuan_video/pipeline_hunyuan_video_1_5.py index 0b68676e8dc..c2c32752cb6 100644 --- a/vllm_omni/diffusion/models/hunyuan_video/pipeline_hunyuan_video_1_5.py 
+++ b/vllm_omni/diffusion/models/hunyuan_video/pipeline_hunyuan_video_1_5.py @@ -25,6 +25,7 @@ from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.models.hunyuan_video.hunyuan_video_15_transformer import HunyuanVideo15Transformer3DModel from vllm_omni.diffusion.models.t5_encoder import T5EncoderModel +from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs from vllm_omni.platforms import current_omni_platform @@ -81,7 +82,7 @@ def post_process_func(video: torch.Tensor, output_type: str = "pil"): return post_process_func -class HunyuanVideo15Pipeline(nn.Module, CFGParallelMixin): +class HunyuanVideo15Pipeline(nn.Module, CFGParallelMixin, DiffusionPipelineProfilerMixin): def __init__( self, *, @@ -173,6 +174,10 @@ def __init__( self._num_timesteps = None self._current_timestep = None + self.setup_diffusion_pipeline_profiler( + enable_diffusion_pipeline_profiler=self.od_config.enable_diffusion_pipeline_profiler + ) + @property def guidance_scale(self): return self._guidance_scale diff --git a/vllm_omni/diffusion/models/hunyuan_video/pipeline_hunyuan_video_1_5_i2v.py b/vllm_omni/diffusion/models/hunyuan_video/pipeline_hunyuan_video_1_5_i2v.py index d68c43125c5..022641394e2 100644 --- a/vllm_omni/diffusion/models/hunyuan_video/pipeline_hunyuan_video_1_5_i2v.py +++ b/vllm_omni/diffusion/models/hunyuan_video/pipeline_hunyuan_video_1_5_i2v.py @@ -39,6 +39,7 @@ ) from vllm_omni.diffusion.models.interface import SupportImageInput from vllm_omni.diffusion.models.t5_encoder import T5EncoderModel +from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs from 
vllm_omni.platforms import current_omni_platform @@ -98,7 +99,7 @@ def pre_process_func(req: OmniDiffusionRequest) -> OmniDiffusionRequest: return pre_process_func -class HunyuanVideo15I2VPipeline(nn.Module, CFGParallelMixin, SupportImageInput): +class HunyuanVideo15I2VPipeline(nn.Module, CFGParallelMixin, SupportImageInput, DiffusionPipelineProfilerMixin): support_image_input = True color_format = "RGB" @@ -199,6 +200,10 @@ def __init__( self._num_timesteps = None self._current_timestep = None + self.setup_diffusion_pipeline_profiler( + enable_diffusion_pipeline_profiler=self.od_config.enable_diffusion_pipeline_profiler + ) + @property def guidance_scale(self): return self._guidance_scale From 1adbc36121d8d3366897a050733a29bddcb7b991 Mon Sep 17 00:00:00 2001 From: Lancer Date: Sun, 5 Apr 2026 14:42:31 +0800 Subject: [PATCH 2/2] upd Signed-off-by: Lancer --- .../models/flux/pipeline_flux_kontext.py | 98 +++++++++-------- .../diffusion/models/flux2/pipeline_flux2.py | 90 ++++++++-------- .../pipeline_hunyuan_video_1_5.py | 98 +++++++++-------- .../pipeline_hunyuan_video_1_5_i2v.py | 102 +++++++++--------- 4 files changed, 204 insertions(+), 184 deletions(-) diff --git a/vllm_omni/diffusion/models/flux/pipeline_flux_kontext.py b/vllm_omni/diffusion/models/flux/pipeline_flux_kontext.py index b232c0a7369..c7574c1c854 100644 --- a/vllm_omni/diffusion/models/flux/pipeline_flux_kontext.py +++ b/vllm_omni/diffusion/models/flux/pipeline_flux_kontext.py @@ -31,6 +31,7 @@ ) from vllm_omni.diffusion.models.flux.flux_pipeline_mixin import FluxPipelineMixin from vllm_omni.diffusion.models.interface import SupportImageInput +from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs @@ -68,7 +69,9 @@ def post_process_func(images: 
torch.Tensor) -> list[PIL.Image.Image]: return post_process_func -class FluxKontextPipeline(nn.Module, FluxPipelineMixin, SupportImageInput, DiffusionPipelineProfilerMixin): +class FluxKontextPipeline( + nn.Module, FluxPipelineMixin, SupportImageInput, ProgressBarMixin, DiffusionPipelineProfilerMixin +): """FLUX.1-Kontext pipeline for image editing with text guidance.""" support_image_input = True @@ -640,58 +643,61 @@ def forward( # 5. Denoising loop self.scheduler.set_begin_index(0) - for i, t in enumerate(timesteps): - if self.interrupt: - continue - - latent_model_input = latents - if image_latents is not None: - latent_model_input = torch.cat([latents, image_latents], dim=1) - timestep = t.expand(latents.shape[0]).to(latents.dtype) - - noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep / 1000, - guidance=guidance, - pooled_projections=pooled_prompt_embeds, - encoder_hidden_states=prompt_embeds, - txt_ids=text_ids, - img_ids=latent_ids, - joint_attention_kwargs=self.joint_attention_kwargs, - return_dict=False, - )[0] - noise_pred = noise_pred[:, : latents.size(1)] - - if do_true_cfg: - neg_noise_pred = self.transformer( + with self.progress_bar(total=len(timesteps)) as pbar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = latents + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1) + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep / 1000, guidance=guidance, - pooled_projections=negative_pooled_prompt_embeds, - encoder_hidden_states=negative_prompt_embeds, - txt_ids=negative_text_ids, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, img_ids=latent_ids, joint_attention_kwargs=self.joint_attention_kwargs, return_dict=False, )[0] - neg_noise_pred = neg_noise_pred[:, : latents.size(1)] - 
noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) - - latents_dtype = latents.dtype - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] - - if latents.dtype != latents_dtype: - if torch.backends.mps.is_available(): - latents = latents.to(latents_dtype) - - if callback_on_step_end is not None: - callback_kwargs = {} - for k in callback_on_step_end_tensor_inputs: - callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) - - latents = callback_outputs.pop("latents", latents) - prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + noise_pred = noise_pred[:, : latents.size(1)] + + if do_true_cfg: + neg_noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=negative_pooled_prompt_embeds, + encoder_hidden_states=negative_prompt_embeds, + txt_ids=negative_text_ids, + img_ids=latent_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + neg_noise_pred = neg_noise_pred[:, : latents.size(1)] + noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + pbar.update() if output_type == "latent": image = latents else: diff --git a/vllm_omni/diffusion/models/flux2/pipeline_flux2.py b/vllm_omni/diffusion/models/flux2/pipeline_flux2.py index 4f8011f2283..cc25c6b7043 
100644 --- a/vllm_omni/diffusion/models/flux2/pipeline_flux2.py +++ b/vllm_omni/diffusion/models/flux2/pipeline_flux2.py @@ -29,6 +29,7 @@ from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.models.flux2 import Flux2Transformer2DModel from vllm_omni.diffusion.models.interface import SupportImageInput +from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs @@ -332,7 +333,7 @@ def retrieve_latents(encoder_output: torch.Tensor, generator: torch.Generator = raise AttributeError("Could not access latents of provided encoder_output") -class Flux2Pipeline(nn.Module, SupportImageInput, DiffusionPipelineProfilerMixin): +class Flux2Pipeline(nn.Module, SupportImageInput, ProgressBarMixin, DiffusionPipelineProfilerMixin): """Flux2 pipeline for text-to-image generation.""" _callback_tensor_inputs = ["latents", "prompt_embeds"] @@ -1032,48 +1033,51 @@ def forward( # We set the index here to remove DtoH sync, helpful especially during compilation. 
# Check out more details here: https://github.com/huggingface/diffusers/pull/11696 self.scheduler.set_begin_index(0) - for i, t in enumerate(timesteps): - if self.interrupt: - continue - - self._current_timestep = t - timestep = t.expand(latents.shape[0]).to(latents.dtype) - - latent_model_input = latents.to(self.transformer.dtype) - latent_image_ids = latent_ids - - if image_latents is not None: - latent_model_input = torch.cat([latents, image_latents], dim=1).to(self.transformer.dtype) - latent_image_ids = torch.cat([latent_ids, image_latent_ids], dim=1) - - noise_pred = self.transformer( - hidden_states=latent_model_input, # (B, image_seq_len, C) - timestep=timestep / 1000, - guidance=guidance_tensor, - encoder_hidden_states=prompt_embeds, - txt_ids=text_ids, # B, text_seq_len, 4 - img_ids=latent_image_ids, # B, image_seq_len, 4 - joint_attention_kwargs=self.attention_kwargs, - return_dict=False, - )[0] - - noise_pred = noise_pred[:, : latents.size(1) :] - - # compute the previous noisy sample x_t -> x_t-1 - latents_dtype = latents.dtype - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] - - if latents.dtype != latents_dtype and torch.backends.mps.is_available(): - latents = latents.to(latents_dtype) - - if callback_on_step_end is not None: - callback_kwargs = {} - for k in callback_on_step_end_tensor_inputs: - callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) - - latents = callback_outputs.pop("latents", latents) - prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + with self.progress_bar(total=len(timesteps)) as pbar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + latent_model_input = latents.to(self.transformer.dtype) + latent_image_ids = latent_ids + + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], 
dim=1).to(self.transformer.dtype) + latent_image_ids = torch.cat([latent_ids, image_latent_ids], dim=1) + + noise_pred = self.transformer( + hidden_states=latent_model_input, # (B, image_seq_len, C) + timestep=timestep / 1000, + guidance=guidance_tensor, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, # B, text_seq_len, 4 + img_ids=latent_image_ids, # B, image_seq_len, 4 + joint_attention_kwargs=self.attention_kwargs, + return_dict=False, + )[0] + + noise_pred = noise_pred[:, : latents.size(1) :] + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype and torch.backends.mps.is_available(): + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + pbar.update() self._current_timestep = None diff --git a/vllm_omni/diffusion/models/hunyuan_video/pipeline_hunyuan_video_1_5.py b/vllm_omni/diffusion/models/hunyuan_video/pipeline_hunyuan_video_1_5.py index c2c32752cb6..6445bfee215 100644 --- a/vllm_omni/diffusion/models/hunyuan_video/pipeline_hunyuan_video_1_5.py +++ b/vllm_omni/diffusion/models/hunyuan_video/pipeline_hunyuan_video_1_5.py @@ -24,6 +24,7 @@ from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.models.hunyuan_video.hunyuan_video_15_transformer import HunyuanVideo15Transformer3DModel +from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin from vllm_omni.diffusion.models.t5_encoder import T5EncoderModel from 
vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin from vllm_omni.diffusion.request import OmniDiffusionRequest @@ -82,7 +83,7 @@ def post_process_func(video: torch.Tensor, output_type: str = "pil"): return post_process_func -class HunyuanVideo15Pipeline(nn.Module, CFGParallelMixin, DiffusionPipelineProfilerMixin): +class HunyuanVideo15Pipeline(nn.Module, CFGParallelMixin, ProgressBarMixin, DiffusionPipelineProfilerMixin): def __init__( self, *, @@ -450,60 +451,63 @@ def forward( timesteps = self.scheduler.timesteps self._num_timesteps = len(timesteps) - for i, t in enumerate(timesteps): - self._current_timestep = t - - latent_model_input = torch.cat([latents, cond_latents, mask], dim=1) - timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype) - - timestep_r = None - if self.use_meanflow: - if i == len(timesteps) - 1: - timestep_r = torch.tensor([0.0], device=device) - else: - timestep_r = timesteps[i + 1] - timestep_r = timestep_r.expand(latents.shape[0]).to(latents.dtype) - - positive_kwargs = { - "hidden_states": latent_model_input, - "timestep": timestep, - "timestep_r": timestep_r, - "encoder_hidden_states": prompt_embeds, - "encoder_attention_mask": prompt_embeds_mask, - "encoder_hidden_states_2": prompt_embeds_2, - "encoder_attention_mask_2": prompt_embeds_mask_2, - "image_embeds": image_embeds, - "return_dict": False, - } - - negative_kwargs = None - if do_cfg and negative_prompt_embeds is not None: - negative_kwargs = { + with self.progress_bar(total=len(timesteps)) as pbar: + for i, t in enumerate(timesteps): + self._current_timestep = t + + latent_model_input = torch.cat([latents, cond_latents, mask], dim=1) + timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype) + + timestep_r = None + if self.use_meanflow: + if i == len(timesteps) - 1: + timestep_r = torch.tensor([0.0], device=device) + else: + timestep_r = timesteps[i + 1] + timestep_r = 
timestep_r.expand(latents.shape[0]).to(latents.dtype) + + positive_kwargs = { "hidden_states": latent_model_input, "timestep": timestep, "timestep_r": timestep_r, - "encoder_hidden_states": negative_prompt_embeds, - "encoder_attention_mask": negative_prompt_embeds_mask, - "encoder_hidden_states_2": negative_prompt_embeds_2, - "encoder_attention_mask_2": negative_prompt_embeds_mask_2, + "encoder_hidden_states": prompt_embeds, + "encoder_attention_mask": prompt_embeds_mask, + "encoder_hidden_states_2": prompt_embeds_2, + "encoder_attention_mask_2": prompt_embeds_mask_2, "image_embeds": image_embeds, "return_dict": False, } - noise_pred = self.predict_noise_maybe_with_cfg( - do_true_cfg=do_cfg and negative_kwargs is not None, - true_cfg_scale=guidance_scale, - positive_kwargs=positive_kwargs, - negative_kwargs=negative_kwargs, - cfg_normalize=req.sampling_params.cfg_normalize, - ) + negative_kwargs = None + if do_cfg and negative_prompt_embeds is not None: + negative_kwargs = { + "hidden_states": latent_model_input, + "timestep": timestep, + "timestep_r": timestep_r, + "encoder_hidden_states": negative_prompt_embeds, + "encoder_attention_mask": negative_prompt_embeds_mask, + "encoder_hidden_states_2": negative_prompt_embeds_2, + "encoder_attention_mask_2": negative_prompt_embeds_mask_2, + "image_embeds": image_embeds, + "return_dict": False, + } + + noise_pred = self.predict_noise_maybe_with_cfg( + do_true_cfg=do_cfg and negative_kwargs is not None, + true_cfg_scale=guidance_scale, + positive_kwargs=positive_kwargs, + negative_kwargs=negative_kwargs, + cfg_normalize=req.sampling_params.cfg_normalize, + ) - latents = self.scheduler_step_maybe_with_cfg( - noise_pred, - t, - latents, - do_true_cfg=do_cfg and negative_kwargs is not None, - ) + latents = self.scheduler_step_maybe_with_cfg( + noise_pred, + t, + latents, + do_true_cfg=do_cfg and negative_kwargs is not None, + ) + + pbar.update() self._current_timestep = None diff --git 
a/vllm_omni/diffusion/models/hunyuan_video/pipeline_hunyuan_video_1_5_i2v.py b/vllm_omni/diffusion/models/hunyuan_video/pipeline_hunyuan_video_1_5_i2v.py index 022641394e2..c1acd1a895a 100644 --- a/vllm_omni/diffusion/models/hunyuan_video/pipeline_hunyuan_video_1_5_i2v.py +++ b/vllm_omni/diffusion/models/hunyuan_video/pipeline_hunyuan_video_1_5_i2v.py @@ -38,6 +38,7 @@ retrieve_latents, ) from vllm_omni.diffusion.models.interface import SupportImageInput +from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin from vllm_omni.diffusion.models.t5_encoder import T5EncoderModel from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin from vllm_omni.diffusion.request import OmniDiffusionRequest @@ -99,7 +100,9 @@ def pre_process_func(req: OmniDiffusionRequest) -> OmniDiffusionRequest: return pre_process_func -class HunyuanVideo15I2VPipeline(nn.Module, CFGParallelMixin, SupportImageInput, DiffusionPipelineProfilerMixin): +class HunyuanVideo15I2VPipeline( + nn.Module, CFGParallelMixin, SupportImageInput, ProgressBarMixin, DiffusionPipelineProfilerMixin +): support_image_input = True color_format = "RGB" @@ -525,61 +528,64 @@ def forward( timesteps = self.scheduler.timesteps self._num_timesteps = len(timesteps) - for i, t in enumerate(timesteps): - self._current_timestep = t - - latent_model_input = torch.cat([latents, cond_latents, mask], dim=1) - timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype) - - timestep_r = None - if self.use_meanflow: - if i == len(timesteps) - 1: - timestep_r = torch.tensor([0.0], device=device) - else: - timestep_r = timesteps[i + 1] - timestep_r = timestep_r.expand(latents.shape[0]).to(latents.dtype) - - positive_kwargs = { - "hidden_states": latent_model_input, - "timestep": timestep, - "timestep_r": timestep_r, - "encoder_hidden_states": prompt_embeds, - "encoder_attention_mask": prompt_embeds_mask, - "encoder_hidden_states_2": prompt_embeds_2, - 
"encoder_attention_mask_2": prompt_embeds_mask_2, - "image_embeds": image_embeds, - "return_dict": False, - } - - negative_kwargs = None - if do_cfg and negative_prompt_embeds is not None: - # For I2V CFG, negative still uses image embeds (only text is unconditional) - negative_kwargs = { + with self.progress_bar(total=len(timesteps)) as pbar: + for i, t in enumerate(timesteps): + self._current_timestep = t + + latent_model_input = torch.cat([latents, cond_latents, mask], dim=1) + timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype) + + timestep_r = None + if self.use_meanflow: + if i == len(timesteps) - 1: + timestep_r = torch.tensor([0.0], device=device) + else: + timestep_r = timesteps[i + 1] + timestep_r = timestep_r.expand(latents.shape[0]).to(latents.dtype) + + positive_kwargs = { "hidden_states": latent_model_input, "timestep": timestep, "timestep_r": timestep_r, - "encoder_hidden_states": negative_prompt_embeds, - "encoder_attention_mask": negative_prompt_embeds_mask, - "encoder_hidden_states_2": negative_prompt_embeds_2, - "encoder_attention_mask_2": negative_prompt_embeds_mask_2, + "encoder_hidden_states": prompt_embeds, + "encoder_attention_mask": prompt_embeds_mask, + "encoder_hidden_states_2": prompt_embeds_2, + "encoder_attention_mask_2": prompt_embeds_mask_2, "image_embeds": image_embeds, "return_dict": False, } - noise_pred = self.predict_noise_maybe_with_cfg( - do_true_cfg=do_cfg and negative_kwargs is not None, - true_cfg_scale=guidance_scale, - positive_kwargs=positive_kwargs, - negative_kwargs=negative_kwargs, - cfg_normalize=req.sampling_params.cfg_normalize, - ) + negative_kwargs = None + if do_cfg and negative_prompt_embeds is not None: + # For I2V CFG, negative still uses image embeds (only text is unconditional) + negative_kwargs = { + "hidden_states": latent_model_input, + "timestep": timestep, + "timestep_r": timestep_r, + "encoder_hidden_states": negative_prompt_embeds, + "encoder_attention_mask": 
negative_prompt_embeds_mask, + "encoder_hidden_states_2": negative_prompt_embeds_2, + "encoder_attention_mask_2": negative_prompt_embeds_mask_2, + "image_embeds": image_embeds, + "return_dict": False, + } + + noise_pred = self.predict_noise_maybe_with_cfg( + do_true_cfg=do_cfg and negative_kwargs is not None, + true_cfg_scale=guidance_scale, + positive_kwargs=positive_kwargs, + negative_kwargs=negative_kwargs, + cfg_normalize=req.sampling_params.cfg_normalize, + ) - latents = self.scheduler_step_maybe_with_cfg( - noise_pred, - t, - latents, - do_true_cfg=do_cfg and negative_kwargs is not None, - ) + latents = self.scheduler_step_maybe_with_cfg( + noise_pred, + t, + latents, + do_true_cfg=do_cfg and negative_kwargs is not None, + ) + + pbar.update() self._current_timestep = None